Souravdanyal commited on
Commit
03314f6
·
1 Parent(s): 7bb5ffa

Fix JSON parsing - handle control characters in LLM response

Browse files
Files changed (1) hide show
  1. inference.py +33 -8
inference.py CHANGED
@@ -96,17 +96,42 @@ def call_llm(buggy_code, instructions, difficulty, feedback=None, attempt=1, pre
96
  temperature=0.1 if attempt == 1 else 0.4,
97
  )
98
  raw = resp.choices[0].message.content.strip()
99
- # Clean markdown fences
100
- if "```" in raw:
101
- raw = raw.split("```")[1] if raw.startswith("```") else raw
102
- if raw.startswith("json\n"):
103
- raw = raw[5:]
104
- # Find JSON object
 
 
 
 
105
  start = raw.find("{")
106
  end = raw.rfind("}") + 1
107
  if start >= 0 and end > start:
108
  raw = raw[start:end]
109
- parsed = json.loads(raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  return {"fixed_code": parsed.get("fixed_code", ""), "explanation": parsed.get("explanation")}
111
  except Exception as e:
112
  print(f"# LLM error: {e}", file=sys.stderr)
@@ -186,4 +211,4 @@ def main():
186
  print(f"# SUMMARY: {sum(successes)}/{len(diffs)} tasks solved | avg_reward={avg}", flush=True)
187
 
188
  if __name__ == "__main__":
189
- main()
 
96
  temperature=0.1 if attempt == 1 else 0.4,
97
  )
98
  raw = resp.choices[0].message.content.strip()
99
+
100
+ # Remove markdown fences
101
+ if "```json" in raw:
102
+ raw = raw.split("```json")[1].split("```")[0].strip()
103
+ elif "```" in raw:
104
+ raw = raw.split("```")[1].split("```")[0].strip()
105
+ if raw.startswith("json"):
106
+ raw = raw[4:].strip()
107
+
108
+ # Find JSON object boundaries
109
  start = raw.find("{")
110
  end = raw.rfind("}") + 1
111
  if start >= 0 and end > start:
112
  raw = raw[start:end]
113
+
114
+ # Try direct parse first
115
+ try:
116
+ parsed = json.loads(raw)
117
+ except json.JSONDecodeError:
118
+ # Fix control characters by replacing literal newlines inside strings
119
+ import re
120
+ # Replace actual newlines within JSON string values with \n escape
121
+ raw = re.sub(r'(?<!\\)\n', r'\\n', raw)
122
+ raw = re.sub(r'(?<!\\)\t', r'\\t', raw)
123
+ raw = re.sub(r'(?<!\\)\r', r'\\r', raw)
124
+ try:
125
+ parsed = json.loads(raw)
126
+ except json.JSONDecodeError:
127
+ # Last resort: extract fixed_code manually using regex
128
+ code_match = re.search(r'"fixed_code"\s*:\s*"(.*?)"(?=\s*[,}])', raw, re.DOTALL)
129
+ exp_match = re.search(r'"explanation"\s*:\s*"(.*?)"(?=\s*[,}])', raw, re.DOTALL)
130
+ if code_match:
131
+ code = code_match.group(1).encode().decode('unicode_escape') if '\\n' in code_match.group(1) else code_match.group(1)
132
+ return {"fixed_code": code, "explanation": exp_match.group(1) if exp_match else None}
133
+ raise
134
+
135
  return {"fixed_code": parsed.get("fixed_code", ""), "explanation": parsed.get("explanation")}
136
  except Exception as e:
137
  print(f"# LLM error: {e}", file=sys.stderr)
 
211
  print(f"# SUMMARY: {sum(successes)}/{len(diffs)} tasks solved | avg_reward={avg}", flush=True)
212
 
213
  if __name__ == "__main__":
214
+ main()