"""Unit tests for ``OpenAIProvider.generate_sql``.

The OpenAI chat-completion call is replaced with a fake that returns a canned
completion. These tests therefore exercise only the provider's response
parsing and its token/cost bookkeeping, never the network.
"""

import json
import os
from types import SimpleNamespace

import pytest

from adapters.llm.openai_provider import OpenAIProvider


class FakeCompletion:
    """Minimal stand-in for the chat-completion object returned by the OpenAI SDK.

    Provides only the attributes these tests rely on:
    ``choices[0].message.content`` and ``usage.prompt_tokens`` /
    ``usage.completion_tokens``. ``usage.total_tokens`` is also set, mirroring
    the SDK's usage object.

    Args:
        content: Text placed in ``choices[0].message.content``.
        prompt_tokens: Value reported as ``usage.prompt_tokens``.
        completion_tokens: Value reported as ``usage.completion_tokens``.
    """

    def __init__(self, content: str, prompt_tokens: int = 5, completion_tokens: int = 7):
        # Real objects (SimpleNamespace) rather than ad-hoc classes built with
        # type(): attribute access then behaves like the SDK's response objects.
        self.choices = [SimpleNamespace(message=SimpleNamespace(content=content))]
        self.usage = SimpleNamespace(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )


def _provider_returning(monkeypatch, content: str) -> OpenAIProvider:
    """Build an ``OpenAIProvider`` whose chat-completion call returns *content*.

    ``provider.client.chat.completions.create`` is monkeypatched to return a
    ``FakeCompletion`` wrapping *content*, with the default token counts of
    5 prompt / 7 completion tokens.
    """
    # NOTE(review): assumes OpenAIProvider() builds an SDK client that refuses
    # to start without an API key. A dummy key keeps the tests hermetic; a real
    # key already in the environment is left untouched. monkeypatch restores
    # the environment after each test.
    if "OPENAI_API_KEY" not in os.environ:
        monkeypatch.setenv("OPENAI_API_KEY", "sk-test-dummy")

    provider = OpenAIProvider()
    fake_completion = FakeCompletion(content)

    def fake_create(*args, **kwargs):
        return fake_completion

    monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)
    return provider


# --- Case 1: clean valid JSON --------------------------------------------------
def test_generate_sql_valid_json(monkeypatch):
    """A well-formed JSON reply is parsed into SQL, rationale, tokens and cost."""
    fake_content = json.dumps({
        "sql": "SELECT * FROM singer;",
        "rationale": "List all singers.",
    })
    provider = _provider_returning(monkeypatch, fake_content)

    sql, rationale, t_in, t_out, cost = provider.generate_sql(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={},
    )

    assert sql.strip().lower().startswith("select")
    assert "singer" in sql.lower()
    assert "list" in rationale.lower()
    # Token counts must come straight from the (fake) usage object.
    assert t_in == 5 and t_out == 7
    assert isinstance(cost, float)
    assert cost >= 0


# --- Case 2: JSON object embedded in surrounding prose (should still recover) --
def test_generate_sql_recover_from_partial_json(monkeypatch):
    """A valid JSON object surrounded by extra prose is still extracted.

    The reply as a whole is not valid JSON, but it contains one well-formed
    object that the provider is expected to recover.
    """
    fake_content = (
        "Here is the result:\n"
        "{ \"sql\": \"SELECT * FROM users;\", \"rationale\": \"list users\" }\n"
        "Thanks!"
    )
    provider = _provider_returning(monkeypatch, fake_content)

    sql, rationale, *_ = provider.generate_sql(
        user_query="show all users",
        schema_preview="CREATE TABLE users(id int, name text);",
        plan_text="-- plan --",
    )

    assert sql.lower().startswith("select")
    assert "user" in sql.lower()
    assert "list" in rationale.lower()


# --- Case 3: completely invalid JSON (should raise ValueError) -----------------
def test_generate_sql_invalid_json(monkeypatch):
    """A reply containing no JSON object at all makes generate_sql raise ValueError."""
    provider = _provider_returning(monkeypatch, "This is nonsense output without braces")

    with pytest.raises(ValueError):
        provider.generate_sql(
            user_query="show X",
            schema_preview="CREATE TABLE t(id int);",
            plan_text="-- plan --",
        )