Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import subprocess | |
| import shutil | |
| import tempfile | |
| import textwrap | |
| from typing import Any, Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State, Observation | |
| try: | |
| from ..models import TcgeneratorAction, TcgeneratorObservation | |
| except ImportError: | |
| from models import TcgeneratorAction, TcgeneratorObservation | |
# ─────────────────────────────────────────────
# TASKS — Java source code
# ─────────────────────────────────────────────
# Each task is one Java class the agent must write JUnit 5 tests for.
# Per-task keys:
#   source_code          Java source compiled alongside the agent's tests.
#   source_class         Public class name in source_code (also its file name).
#   test_class           Public class name the agent's test file must declare.
#   task_hint            Prompt text shown to the agent.
#   expected_min_tests   Test count needed for compute_reward's quantity bonus.
#   requires_edge_cases  Whether compute_reward checks for edge-case evidence.
#   requires_exception   Whether compute_reward checks for assertThrows usage.
TASKS = {
    "easy": {
        "source_code": textwrap.dedent("""
            public class Calculator {
                public int add(int a, int b) {
                    return a + b;
                }
                public boolean isPalindrome(String s) {
                    String clean = s.toLowerCase().replace(" ", "");
                    String reversed = new StringBuilder(clean).reverse().toString();
                    return clean.equals(reversed);
                }
                public double celsiusToFahrenheit(double c) {
                    return c * 9.0 / 5.0 + 32;
                }
            }
        """).strip(),
        "source_class": "Calculator",
        "test_class": "CalculatorTest",
        "task_hint": (
            "Write JUnit 5 tests for Calculator class. "
            "1) Class name MUST be CalculatorTest. "
            "2) Use @Test annotation on every test method. "
            "3) Import: import org.junit.jupiter.api.Test; "
            "4) Import: import static org.junit.jupiter.api.Assertions.*; "
            "5) Write at least 6 test methods. "
            "6) Test add(), isPalindrome(), celsiusToFahrenheit(). "
            "Reply with ONLY the Java code."
        ),
        "expected_min_tests": 6,
        "requires_edge_cases": False,
        "requires_exception": False,
    },
    "medium": {
        "source_code": textwrap.dedent("""
            public class SafeMath {
                public double safeDivide(double a, double b) {
                    if (b == 0)
                        throw new IllegalArgumentException("Cannot divide by zero");
                    return a / b;
                }
                public int getFirstElement(int[] arr) {
                    if (arr.length == 0)
                        throw new IndexOutOfBoundsException("Array is empty");
                    return arr[0];
                }
                public int parsePositiveInt(String s) {
                    int val = Integer.parseInt(s);
                    if (val <= 0)
                        throw new IllegalArgumentException("Must be positive");
                    return val;
                }
            }
        """).strip(),
        "source_class": "SafeMath",
        "test_class": "SafeMathTest",
        "task_hint": (
            "Write JUnit 5 tests for SafeMath class. "
            "1) Class name MUST be SafeMathTest. "
            "2) Use @Test annotation on every test method. "
            "3) Import: import org.junit.jupiter.api.Test; "
            "4) Import: import static org.junit.jupiter.api.Assertions.*; "
            "5) Use assertThrows() for exception testing. "
            "6) Test edge cases: empty array, zero, negative. "
            "7) Write at least 8 test methods. "
            "Reply with ONLY the Java code."
        ),
        "expected_min_tests": 8,
        "requires_edge_cases": True,
        "requires_exception": True,
    },
    "hard": {
        "source_code": textwrap.dedent("""
            import java.util.ArrayList;
            import java.util.List;
            public class BankAccount {
                private String owner;
                private double balance;
                private List<String> transactions;
                public BankAccount(String owner, double balance) {
                    this.owner = owner;
                    this.balance = balance;
                    this.transactions = new ArrayList<>();
                }
                public double deposit(double amount) {
                    if (amount <= 0)
                        throw new IllegalArgumentException("Deposit must be positive");
                    this.balance += amount;
                    this.transactions.add("deposit:" + amount);
                    return this.balance;
                }
                public double withdraw(double amount) {
                    if (amount <= 0)
                        throw new IllegalArgumentException("Withdrawal must be positive");
                    if (amount > this.balance)
                        throw new IllegalArgumentException("Insufficient funds");
                    this.balance -= amount;
                    this.transactions.add("withdraw:" + amount);
                    return this.balance;
                }
                public int getTransactionCount() { return this.transactions.size(); }
                public double getBalance() { return this.balance; }
                public String getOwner() { return this.owner; }
            }
        """).strip(),
        "source_class": "BankAccount",
        "test_class": "BankAccountTest",
        "task_hint": (
            "Write JUnit 5 tests for BankAccount class. "
            "1) Class name MUST be BankAccountTest. "
            "2) Use @Test annotation on every test method. "
            "3) Import: import org.junit.jupiter.api.Test; "
            "4) Import: import static org.junit.jupiter.api.Assertions.*; "
            "5) Use assertThrows() for exception testing. "
            "6) Test deposit, withdraw, getBalance, getTransactionCount, getOwner. "
            # Spell out "negative withdraw": the bare word "withdraw" was ambiguous.
            "7) Test exceptions: negative deposit, negative withdraw, insufficient funds. "
            "8) Write at least 10 test methods. "
            "Reply with ONLY the Java code."
        ),
        "expected_min_tests": 10,
        "requires_edge_cases": True,
        "requires_exception": True,
    },
}
# Difficulty names in insertion order (easy, medium, hard); derived from TASKS
# so adding a task can never leave this tuple out of sync.
DIFFICULTIES = tuple(TASKS)
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # JAVA DETECTION | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| def find_java(): | |
| for cmd in ["java", "/usr/bin/java", "/usr/local/bin/java"]: | |
| try: | |
| result = subprocess.run([cmd, "-version"], capture_output=True, text=True) | |
| if result.returncode == 0: | |
| print(f"[DEBUG] Java found: {cmd}", flush=True) | |
| return cmd | |
| except FileNotFoundError: | |
| continue | |
| return None | |
| def find_javac(): | |
| for cmd in ["javac", "/usr/bin/javac", "/usr/local/bin/javac"]: | |
| try: | |
| result = subprocess.run([cmd, "-version"], capture_output=True, text=True) | |
| if result.returncode == 0: | |
| print(f"[DEBUG] javac found: {cmd}", flush=True) | |
| return cmd | |
| except FileNotFoundError: | |
| continue | |
| return None | |
def find_junit_jar():
    """Locate the JUnit Platform console-standalone jar.

    The ``JUNIT_JAR`` environment variable, when set, is checked first so a
    deployment can point at its jar without editing this file; after that a
    fixed list of known locations is searched.

    Returns:
        Absolute path of the first candidate that is an existing file, or
        ``None`` if none is found. The path is absolute because the jar is
        later passed to javac/java running with a different working directory.
    """
    candidates = [
        "/app/junit-platform-console-standalone.jar",
        "/app/env/junit-platform-console-standalone.jar",
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "..",
                     "junit-platform-console-standalone.jar"),
        # NOTE(review): developer-machine path; prefer setting JUNIT_JAR.
        "/Users/vidhikoul/Desktop/UnitTestCaseGenerator/tcgenerator/junit-platform-console-standalone.jar",
    ]
    override = os.environ.get("JUNIT_JAR")
    if override:
        candidates.insert(0, override)
    for candidate in candidates:
        path = os.path.abspath(candidate)
        if os.path.isfile(path):
            print(f"[DEBUG] JUnit jar found: {path}", flush=True)
            return path
    return None
# ─────────────────────────────────────────────
# JUNIT SANDBOX
# ─────────────────────────────────────────────
def run_junit_tests(source_code: str, test_code: str,
                    source_class: str, test_class: str,
                    timeout: int = 30):
    """Compile a Java class plus its JUnit 5 test class and run the tests.

    Both sources are written to a fresh temporary directory, which is always
    removed afterwards.

    Args:
        source_code: Java source of the class under test.
        test_code: Java source of the JUnit 5 test class (agent-generated).
        source_class: Public class name in ``source_code``; used as its file name.
        test_class: Public class name in ``test_code``; used as its file name
            and as the class selected for execution.
        timeout: Seconds allowed for compilation and, separately, for the run.

    Returns:
        ``(passed, failed, total, error)``. ``error`` is ``None`` only for a
        clean run (exit code 0, at least one pass, no failures). Otherwise it
        is a short diagnostic: a missing-toolchain message, the compile error,
        the tail of the JUnit output, ``"Timeout"``, or an exception message.
    """
    java = find_java()
    javac = find_javac()
    junit = find_junit_jar()
    if not java or not javac:
        return 0, 0, 0, "Java not found!"
    if not junit:
        return 0, 0, 0, "JUnit jar not found!"

    # NOTE(review): test_code is agent-generated and runs with this process's
    # privileges; nothing here restricts it beyond the timeouts.
    tmpdir = tempfile.mkdtemp()
    try:
        source_file = os.path.join(tmpdir, f"{source_class}.java")
        test_file = os.path.join(tmpdir, f"{test_class}.java")
        # Write and compile as UTF-8 explicitly so non-ASCII literals do not
        # depend on the platform's default encoding.
        with open(source_file, "w", encoding="utf-8") as f:
            f.write(source_code)
        with open(test_file, "w", encoding="utf-8") as f:
            f.write(test_code)

        compile_result = subprocess.run(
            [javac, "-encoding", "UTF-8", "-cp", junit, source_file, test_file],
            capture_output=True, text=True,
            cwd=tmpdir, timeout=timeout,
        )
        if compile_result.returncode != 0:
            err = compile_result.stderr[:500]
            print(f"[DEBUG] Compile error: {err}", flush=True)
            return 0, 0, 0, f"Compile error: {err}"

        run_result = subprocess.run(
            [java, "-jar", junit, "-cp", tmpdir,
             "--select-class", test_class, "--details", "summary"],
            capture_output=True, text=True,
            cwd=tmpdir, timeout=timeout,
        )
        output = run_result.stdout + run_result.stderr
        print(f"[DEBUG] JUnit output:\n{output}", flush=True)

        passed, failed = parse_junit_output(output)
        total = passed + failed
        # A partially failing suite must still surface the output, otherwise
        # the agent gets no hint about which assertions broke.
        clean_run = run_result.returncode == 0 and passed > 0 and failed == 0
        error = None if clean_run else output[-400:]
        return passed, failed, total, error
    except subprocess.TimeoutExpired:
        return 0, 0, 0, "Timeout"
    except Exception as e:  # boundary: report any sandbox failure to the caller
        return 0, 0, 0, str(e)
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def parse_junit_output(output: str):
    """Extract ``(passed, failed)`` counts from JUnit console launcher output.

    The summary lines (``N tests successful`` / ``N tests failed``, matched
    case-insensitively, first occurrence of each) take precedence. When both
    of those counts are zero or absent, the function instead counts ``[ OK ]``
    and ``[ FAILED ]`` markers (case-sensitive).
    """
    def first_count(pattern):
        hit = re.search(pattern, output, re.IGNORECASE)
        return int(hit.group(1)) if hit else 0

    successes = first_count(r'(\d+)\s+tests?\s+successful')
    failures = first_count(r'(\d+)\s+tests?\s+failed')
    if successes or failures:
        return successes, failures

    ok_marks = re.findall(r'\[\s*OK\s*\]', output)
    failed_marks = re.findall(r'\[\s*FAILED\s*\]', output)
    return len(ok_marks), len(failed_marks)
# Lower-cased substrings that count as evidence of edge-case testing.
# NOTE(review): "[]" also matches array type declarations such as "int[]".
_EDGE_KEYWORDS = ("empty", "zero", "null", "negative", "[]", '""')
# A standalone zero literal (0, 0.0, 0d, 0l, -0 ...). Matching the bare
# substring "0" accepted any number containing a zero digit (10, 2.0).
_ZERO_LITERAL = re.compile(r"(?<![\w.])0(?:\.0*)?[dfl]?(?![\w.])")


def compute_reward(passed, total, test_code, task_cfg):
    """Score one test submission in ``[0.0, 1.0]``.

    Components:
      * ``0.7 * passed / total`` for the pass rate;
      * ``0.1`` if ``total >= task_cfg["expected_min_tests"]``;
      * ``0.1`` edge-case bonus (edge keyword or zero literal in the code);
      * ``0.1`` exception bonus (``assertThrows`` in the code).

    A bonus whose requirement flag (``requires_edge_cases`` /
    ``requires_exception``) is False is granted outright, so every task can
    reach 1.0. Otherwise the easy task would be capped at 0.8, below the 0.95
    early-termination threshold used by the environment's ``step``.

    Args:
        passed: Number of passing tests.
        total: Number of executed tests (passed + failed).
        test_code: Agent's Java test source; ``None`` is treated as empty.
        task_cfg: Task entry with the keys read above.

    Returns:
        The reward rounded to 3 decimals; 0.0 when no tests ran.
    """
    if total <= 0:
        return 0.0
    code = (test_code or "").lower()

    base = (passed / total) * 0.7
    quantity_bonus = 0.1 if total >= task_cfg["expected_min_tests"] else 0.0

    if task_cfg["requires_edge_cases"]:
        has_edge_case = (any(kw in code for kw in _EDGE_KEYWORDS)
                         or _ZERO_LITERAL.search(code) is not None)
        edge_bonus = 0.1 if has_edge_case else 0.0
    else:
        edge_bonus = 0.1  # not required for this task

    if task_cfg["requires_exception"]:
        exception_bonus = 0.1 if "assertthrows" in code else 0.0
    else:
        exception_bonus = 0.1  # not required for this task

    return round(min(base + quantity_bonus + edge_bonus + exception_bonus, 1.0), 3)
# ─────────────────────────────────────────────
# ENVIRONMENT CLASS
# ─────────────────────────────────────────────
class TcgeneratorEnvironment(Environment):
    """OpenEnv environment in which an agent writes JUnit 5 tests for Java code.

    Each episode serves one task from ``TASKS``. Every ``step`` compiles and
    runs the submitted test class via ``run_junit_tests`` and scores it with
    ``compute_reward``. An episode ends after ``_max_steps`` steps or as soon
    as a step's reward reaches 0.95.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._episode_count = 0   # resets served so far; drives auto-cycling
        self._step_count = 0      # steps taken in the current episode
        self._max_steps = 6       # step budget per episode
        self._best_reward = 0.0   # best reward seen in the current episode
        self._difficulty = "easy"
        self._task_cfg = TASKS["easy"]

    def reset(
        self,
        difficulty: Optional[str] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> TcgeneratorObservation:
        """Start a new episode at the given difficulty.

        Args:
            difficulty: One of ``DIFFICULTIES`` ("easy" | "medium" | "hard").
                When missing or unknown, difficulties are cycled
                easy -> medium -> hard -> easy ... based on the reset count.
            episode_id: Optional custom episode ID; a UUID4 is used otherwise.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            TcgeneratorObservation with the task's source code and hint, zeroed
            counters, and metadata (difficulty, episode_id, message).
        """
        if difficulty not in DIFFICULTIES:
            if difficulty is not None:
                print(f"[DEBUG] Unknown difficulty {difficulty!r}; auto-cycling",
                      flush=True)
            # Auto cycle: easy -> medium -> hard -> easy -> ...
            difficulty = DIFFICULTIES[self._episode_count % len(DIFFICULTIES)]
        self._episode_count += 1

        self._difficulty = difficulty
        self._task_cfg = TASKS[difficulty]
        self._step_count = 0
        self._best_reward = 0.0
        self._state = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )
        print(f"[DEBUG] Reset called — difficulty: {difficulty}", flush=True)
        return TcgeneratorObservation(
            source_code=self._task_cfg["source_code"],
            task_hint=self._task_cfg["task_hint"],
            passed=0, failed=0, total=0,
            error=None, reward=0.0, done=False,
            metadata={
                "difficulty": difficulty,
                "episode_id": self._state.episode_id,
                "message": (
                    f"New {difficulty} task loaded. "
                    f"Write JUnit 5 tests for {self._task_cfg['source_class']} class."
                ),
            },
        )

    def step(self, action: TcgeneratorAction) -> TcgeneratorObservation:
        """Compile and run the submitted tests, then score them.

        Args:
            action: Carries the agent's JUnit 5 source in ``action.test_code``.

        Returns:
            Observation with pass/fail counts, error text, reward and ``done``.
            ``done`` is True once ``_max_steps`` steps were taken or the reward
            is at least 0.95.
        """
        self._state.step_count += 1
        self._step_count += 1
        done = self._step_count >= self._max_steps

        passed, failed, total, error = run_junit_tests(
            source_code=self._task_cfg["source_code"],
            test_code=action.test_code,
            source_class=self._task_cfg["source_class"],
            test_class=self._task_cfg["test_class"],
        )
        reward = compute_reward(passed, total, action.test_code, self._task_cfg)
        self._best_reward = max(self._best_reward, reward)
        if reward >= 0.95:
            done = True

        return TcgeneratorObservation(
            source_code=self._task_cfg["source_code"],
            task_hint=self._task_cfg["task_hint"],
            passed=passed, failed=failed, total=total,
            error=error, reward=reward, done=done,
            metadata={
                "difficulty": self._difficulty,
                "episode_id": self._state.episode_id,
                "step": self._step_count,
                "best_reward": self._best_reward,
            },
        )

    @property
    def state(self) -> State:
        """Current episode ``State`` (episode_id and step_count)."""
        return self._state
| # import os | |
| # import re | |
| # import subprocess | |
| # import shutil | |
| # import tempfile | |
| # import textwrap | |
| # from typing import Optional | |
| # from uuid import uuid4 | |
| # from openenv.core.env_server.interfaces import Environment | |
| # from openenv.core.env_server.types import State | |
| # try: | |
| # import sys | |
| # sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
| # from models import TcgeneratorAction, TcgeneratorObservation | |
| # except ImportError: | |
| # from models import TcgeneratorAction, TcgeneratorObservation | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # # TASKS β Java source code | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # TASKS = { | |
| # "test-easy": { | |
| # "source_code": textwrap.dedent(""" | |
| # public class Calculator { | |
| # public int add(int a, int b) { | |
| # return a + b; | |
| # } | |
| # public boolean isPalindrome(String s) { | |
| # String clean = s.toLowerCase().replace(" ", ""); | |
| # String reversed = new StringBuilder(clean).reverse().toString(); | |
| # return clean.equals(reversed); | |
| # } | |
| # public double celsiusToFahrenheit(double c) { | |
| # return c * 9.0 / 5.0 + 32; | |
| # } | |
| # } | |
| # """).strip(), | |
| # "source_class": "Calculator", | |
| # "test_class": "CalculatorTest", | |
| # "task_hint": ( | |
| # "Write JUnit 5 tests for Calculator class. " | |
| # "Rules: " | |
| # "1) Class name MUST be CalculatorTest. " | |
| # "2) Use @Test annotation on every test method. " | |
| # "3) Import: import org.junit.jupiter.api.Test; " | |
| # "4) Import: import static org.junit.jupiter.api.Assertions.*; " | |
| # "5) Write at least 6 test methods. " | |
| # "6) Test add(), isPalindrome(), celsiusToFahrenheit(). " | |
| # "Reply with ONLY the Java code." | |
| # ), | |
| # "expected_min_tests": 6, | |
| # "requires_edge_cases": False, | |
| # "requires_exception": False, | |
| # }, | |
| # "test-medium": { | |
| # "source_code": textwrap.dedent(""" | |
| # public class SafeMath { | |
| # public double safeDivide(double a, double b) { | |
| # if (b == 0) | |
| # throw new IllegalArgumentException("Cannot divide by zero"); | |
| # return a / b; | |
| # } | |
| # public int getFirstElement(int[] arr) { | |
| # if (arr.length == 0) | |
| # throw new IndexOutOfBoundsException("Array is empty"); | |
| # return arr[0]; | |
| # } | |
| # public int parsePositiveInt(String s) { | |
| # int val = Integer.parseInt(s); | |
| # if (val <= 0) | |
| # throw new IllegalArgumentException("Must be positive"); | |
| # return val; | |
| # } | |
| # } | |
| # """).strip(), | |
| # "source_class": "SafeMath", | |
| # "test_class": "SafeMathTest", | |
| # "task_hint": ( | |
| # "Write JUnit 5 tests for SafeMath class. " | |
| # "Rules: " | |
| # "1) Class name MUST be SafeMathTest. " | |
| # "2) Use @Test annotation on every test method. " | |
| # "3) Import: import org.junit.jupiter.api.Test; " | |
| # "4) Import: import static org.junit.jupiter.api.Assertions.*; " | |
| # "5) Use assertThrows() for exception testing. " | |
| # "6) Test edge cases: empty array, zero, negative numbers. " | |
| # "7) Write at least 8 test methods. " | |
| # "Reply with ONLY the Java code." | |
| # ), | |
| # "expected_min_tests": 8, | |
| # "requires_edge_cases": True, | |
| # "requires_exception": True, | |
| # }, | |
| # "test-hard": { | |
| # "source_code": textwrap.dedent(""" | |
| # import java.util.ArrayList; | |
| # import java.util.List; | |
| # public class BankAccount { | |
| # private String owner; | |
| # private double balance; | |
| # private List<String> transactions; | |
| # public BankAccount(String owner, double balance) { | |
| # this.owner = owner; | |
| # this.balance = balance; | |
| # this.transactions = new ArrayList<>(); | |
| # } | |
| # public double deposit(double amount) { | |
| # if (amount <= 0) | |
| # throw new IllegalArgumentException("Deposit must be positive"); | |
| # this.balance += amount; | |
| # this.transactions.add("deposit:" + amount); | |
| # return this.balance; | |
| # } | |
| # public double withdraw(double amount) { | |
| # if (amount <= 0) | |
| # throw new IllegalArgumentException("Withdrawal must be positive"); | |
| # if (amount > this.balance) | |
| # throw new IllegalArgumentException("Insufficient funds"); | |
| # this.balance -= amount; | |
| # this.transactions.add("withdraw:" + amount); | |
| # return this.balance; | |
| # } | |
| # public int getTransactionCount() { | |
| # return this.transactions.size(); | |
| # } | |
| # public double getBalance() { | |
| # return this.balance; | |
| # } | |
| # public String getOwner() { | |
| # return this.owner; | |
| # } | |
| # } | |
| # """).strip(), | |
| # "source_class": "BankAccount", | |
| # "test_class": "BankAccountTest", | |
| # "task_hint": ( | |
| # "Write JUnit 5 tests for BankAccount class. " | |
| # "Rules: " | |
| # "1) Class name MUST be BankAccountTest. " | |
| # "2) Use @Test annotation on every test method. " | |
| # "3) Import: import org.junit.jupiter.api.Test; " | |
| # "4) Import: import static org.junit.jupiter.api.Assertions.*; " | |
| # "5) Use assertThrows() for exception testing. " | |
| # "6) Test deposit, withdraw, getBalance, getTransactionCount, getOwner. " | |
| # "7) Test exceptions: negative deposit, negative withdraw, insufficient funds. " | |
| # "8) Write at least 10 test methods. " | |
| # "Reply with ONLY the Java code." | |
| # ), | |
| # "expected_min_tests": 10, | |
| # "requires_edge_cases": True, | |
| # "requires_exception": True, | |
| # }, | |
| # } | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # # JAVA DETECTION | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # def find_java(): | |
| # """System mein java aur javac dhundho""" | |
| # for cmd in ["java", "/usr/bin/java", "/usr/local/bin/java"]: | |
| # try: | |
| # result = subprocess.run( | |
| # [cmd, "--version"], | |
| # capture_output=True, text=True | |
| # ) | |
| # if result.returncode == 0: | |
| # print(f"[DEBUG] Java found: {cmd}", flush=True) | |
| # return cmd | |
| # except FileNotFoundError: | |
| # continue | |
| # return None | |
| # def find_javac(): | |
| # """javac compiler dhundho""" | |
| # for cmd in ["javac", "/usr/bin/javac", "/usr/local/bin/javac"]: | |
| # try: | |
| # result = subprocess.run( | |
| # [cmd, "--version"], | |
| # capture_output=True, text=True | |
| # ) | |
| # if result.returncode == 0: | |
| # print(f"[DEBUG] javac found: {cmd}", flush=True) | |
| # return cmd | |
| # except FileNotFoundError: | |
| # continue | |
| # return None | |
| # def find_junit_jar(): | |
| # """JUnit jar file dhundho""" | |
| # candidates = [ | |
| # "/Users/vidhikoul/Desktop/UnitTestCaseGenerator/tcgenerator/junit-platform-console-standalone.jar", | |
| # "/app/junit-platform-console-standalone.jar", | |
| # "/app/env/junit-platform-console-standalone.jar", | |
| # "/junit-platform-console-standalone.jar", | |
| # ] | |
| # for path in candidates: | |
| # if os.path.exists(path): | |
| # print(f"[DEBUG] JUnit jar found: {path}", flush=True) | |
| # return path | |
| # return None | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # # JUNIT SANDBOX | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # def run_junit_tests(source_code: str, test_code: str, | |
| # source_class: str, test_class: str, | |
| # timeout: int = 30): | |
| # """ | |
| # Java source + JUnit test compile karke run karo | |
| # Returns (passed, failed, total, error) | |
| # """ | |
| # java = find_java() | |
| # javac = find_javac() | |
| # junit = find_junit_jar() | |
| # if not java or not javac: | |
| # return 0, 0, 0, "Java not found! Install JDK in Docker." | |
| # if not junit: | |
| # return 0, 0, 0, "JUnit jar not found at /app/junit-platform-console-standalone.jar" | |
| # tmpdir = tempfile.mkdtemp() | |
| # try: | |
| # # Step 1 β Java files likho | |
| # source_file = os.path.join(tmpdir, f"{source_class}.java") | |
| # test_file = os.path.join(tmpdir, f"{test_class}.java") | |
| # with open(source_file, "w") as f: | |
| # f.write(source_code) | |
| # with open(test_file, "w") as f: | |
| # f.write(test_code) | |
| # # Step 2 β Compile karo | |
| # compile_result = subprocess.run( | |
| # [javac, "-cp", junit, source_file, test_file], | |
| # capture_output=True, text=True, | |
| # cwd=tmpdir, timeout=30 | |
| # ) | |
| # if compile_result.returncode != 0: | |
| # err = compile_result.stderr[:500] | |
| # print(f"[DEBUG] Compile error: {err}", flush=True) | |
| # return 0, 0, 0, f"Compile error: {err}" | |
| # run_result = subprocess.run( | |
| # [java, "-jar", junit, "-cp", tmpdir, | |
| # "--select-class", test_class, "--details", "summary"], | |
| # capture_output=True, | |
| # text=True, | |
| # cwd=tmpdir, | |
| # timeout=timeout, | |
| # ) | |
| # output = run_result.stdout + run_result.stderr | |
| # print(f"[DEBUG] JUnit output: {output[:300]}", flush=True) | |
| # # Step 4 β Parse results | |
| # passed, failed = parse_junit_output(output) | |
| # total = passed + failed | |
| # error = None if run_result.returncode == 0 else output[-400:] | |
| # return passed, failed, total, error | |
| # except subprocess.TimeoutExpired: | |
| # return 0, 0, 0, "Timeout: tests took too long" | |
| # except Exception as e: | |
| # return 0, 0, 0, str(e) | |
| # finally: | |
| # shutil.rmtree(tmpdir, ignore_errors=True) | |
| # def parse_junit_output(output: str): | |
| # """JUnit 5 output se passed/failed count nikalo""" | |
| # passed = failed = 0 | |
| # # Format: "3 tests successful" | |
| # p = re.findall(r'(\d+)\s+tests?\s+successful', output, re.IGNORECASE) | |
| # if p: | |
| # passed = int(p[0]) | |
| # # Format: "2 tests failed" | |
| # f = re.findall(r'(\d+)\s+tests?\s+failed', output, re.IGNORECASE) | |
| # if f: | |
| # failed = int(f[0]) | |
| # # Alternate format: "[ OK ]" style | |
| # if passed == 0 and failed == 0: | |
| # passed = len(re.findall(r'\[\s*OK\s*\]', output)) | |
| # failed = len(re.findall(r'\[\s*FAILED\s*\]', output)) | |
| # return passed, failed | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # # REWARD | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # def compute_reward(passed, total, test_code, task_cfg): | |
| # if total == 0: | |
| # return 0.0 | |
| # base = (passed / total) * 0.7 | |
| # quantity_bonus = 0.1 if total >= task_cfg["expected_min_tests"] else 0.0 | |
| # edge_bonus = 0.0 | |
| # if task_cfg["requires_edge_cases"]: | |
| # edge_keywords = ["empty", "zero", "null", "negative", "0", "[]", '""'] | |
| # if any(kw in test_code.lower() for kw in edge_keywords): | |
| # edge_bonus = 0.1 | |
| # exception_bonus = 0.0 | |
| # if task_cfg["requires_exception"]: | |
| # if "assertthrows" in test_code.lower() or "expected=" in test_code.lower(): | |
| # exception_bonus = 0.1 | |
| # return round(min(base + quantity_bonus + edge_bonus + exception_bonus, 1.0), 3) | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # # ENVIRONMENT CLASS | |
| # # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # class TcgeneratorEnvironment(Environment): | |
| # SUPPORTS_CONCURRENT_SESSIONS: bool = True | |
| # def __init__(self, task: str = "test-easy"): | |
| # self._state = State(episode_id=str(uuid4()), step_count=0) | |
| # if task not in TASKS: | |
| # raise ValueError(f"Unknown task '{task}'. Choose from: {list(TASKS.keys())}") | |
| # self.task_name = task | |
| # self.task_cfg = TASKS[task] | |
| # self._step_count = 0 | |
| # self._max_steps = 6 | |
| # self._best_reward = 0.0 | |
| # def reset(self) -> TcgeneratorObservation: | |
| # self._state = State(episode_id=str(uuid4()), step_count=0) | |
| # self._step_count = 0 | |
| # self._best_reward = 0.0 | |
| # return TcgeneratorObservation( | |
| # source_code=self.task_cfg["source_code"], | |
| # task_hint=self.task_cfg["task_hint"], | |
| # passed=0, | |
| # failed=0, | |
| # total=0, | |
| # error=None, | |
| # reward=0.0, | |
| # done=False, | |
| # ) | |
| # def step(self, action: TcgeneratorAction) -> TcgeneratorObservation: | |
| # self._state.step_count += 1 | |
| # self._step_count += 1 | |
| # done = self._step_count >= self._max_steps | |
| # passed, failed, total, error = run_junit_tests( | |
| # source_code = self.task_cfg["source_code"], | |
| # test_code = action.test_code, | |
| # source_class = self.task_cfg["source_class"], | |
| # test_class = self.task_cfg["test_class"], | |
| # ) | |
| # reward = compute_reward(passed, total, action.test_code, self.task_cfg) | |
| # self._best_reward = max(self._best_reward, reward) | |
| # if reward >= 0.95: | |
| # done = True | |
| # return TcgeneratorObservation( | |
| # source_code=self.task_cfg["source_code"], | |
| # task_hint=self.task_cfg["task_hint"], | |
| # passed=passed, | |
| # failed=failed, | |
| # total=total, | |
| # error=error, | |
| # reward=reward, | |
| # done=done, | |
| # ) | |
| # @property | |
| # def state(self) -> State: | |
| # return self._state | |
if __name__ == "__main__":
    # Smoke test: run a small hand-written JUnit 5 suite against the "easy"
    # task's Calculator class. Needs java, javac and the JUnit console jar.
    demo_cfg = TASKS["easy"]
    demo_tests = textwrap.dedent("""
        import org.junit.jupiter.api.Test;
        import static org.junit.jupiter.api.Assertions.*;

        public class CalculatorTest {
            private final Calculator calculator = new Calculator();

            @Test
            void testAdd() {
                assertEquals(5, calculator.add(2, 3));
                assertEquals(-1, calculator.add(-2, 1));
            }

            @Test
            void testIsPalindrome() {
                assertTrue(calculator.isPalindrome("A man a plan a canal Panama"));
                assertFalse(calculator.isPalindrome("hello"));
            }

            @Test
            void testCelsiusToFahrenheit() {
                assertEquals(32.0, calculator.celsiusToFahrenheit(0), 0.001);
                assertEquals(212.0, calculator.celsiusToFahrenheit(100), 0.001);
            }
        }
    """).strip()
    passed, failed, total, error = run_junit_tests(
        source_code=demo_cfg["source_code"],
        test_code=demo_tests,
        source_class=demo_cfg["source_class"],
        test_class=demo_cfg["test_class"],
    )
    print(f"Passed: {passed}, Failed: {failed}, Total: {total}, Error: {error}")
| # # java_code = textwrap.dedent(""" | |
| # # public class Calculator { | |
| # # public int add(int a, int b) { | |
| # # return a + b; | |
| # # } | |
| # # public boolean isPalindrome(String s) { | |
| # # String clean = s.toLowerCase().replace(" ", ""); | |
| # # String reversed = new StringBuilder(clean).reverse().toString(); | |
| # # return clean.equals(reversed); | |
| # # } | |
| # # public double celsiusToFahrenheit(double c) { | |
| # # return c * 9.0 / 5.0 + 32; | |
| # # } | |
| # # } | |
| # # """).strip() | |
| # # test_code = textwrap.dedent(""" | |
| # # import org.junit.Test; | |
| # # import static org.junit.Assert.*; | |
| # # public class CalculatorTest { | |
| # # private Calculator calculator = new Calculator(); | |
| # # @Test | |
| # # public void testAdd() { | |
| # # assertEquals(5, calculator.add(2, 3)); | |
| # # assertEquals(-1, calculator.add(-2, 1)); | |
| # # assertEquals(0, calculator.add(0, 0)); | |
| # # } | |
| # # @Test | |
| # # public void testIsPalindrome() { | |
| # # assertTrue(calculator.isPalindrome("A man a plan a canal Panama")); | |
| # # assertTrue(calculator.isPalindrome("racecar")); | |
| # # assertFalse(calculator.isPalindrome("hello")); | |
| # # } | |
| # # @Test | |
| # # public void testCelsiusToFahrenheit() { | |
| # # assertEquals(32.0, calculator.celsiusToFahrenheit(0), 0.001); | |
| # # assertEquals(212.0, calculator.celsiusToFahrenheit(100), 0.001); | |
| # # assertEquals(98.6, calculator.celsiusToFahrenheit(37), 0.001); | |
| # # } | |
| # # } | |
| # # """).strip() | |
| # # passed, failed, total, error = run_junit_tests( | |
| # # source_code=java_code, | |
| # # test_code=test_code, | |
| # # source_class="Calculator", | |
| # # test_class="CalculatorTest" | |
| # # ) | |
| # print(f"Passed: {passed}, Failed: {failed}, Total: {total}, Error: {error}") |