Spaces:
Sleeping
Sleeping
File size: 2,956 Bytes
570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import json
import pytest
from adapters.llm.openai_provider import OpenAIProvider
# Test double standing in for the completion object the OpenAI SDK returns.
class FakeCompletion:
    """Minimal fake of an OpenAI chat completion.

    Only the attributes the tests read are provided:
    ``choices[0].message.content``, ``usage.prompt_tokens`` and
    ``usage.completion_tokens``. The nested objects are ad-hoc classes
    created with three-argument ``type()``; their values are read as
    class attributes.

    Args:
        content: Text exposed as the single choice's message content.
        prompt_tokens: Value exposed as ``usage.prompt_tokens``.
        completion_tokens: Value exposed as ``usage.completion_tokens``.
    """

    def __init__(self, content: str, prompt_tokens=5, completion_tokens=7):
        message_cls = type("Msg", (), {"content": content})
        choice_cls = type("Choice", (), {"message": message_cls})
        usage_fields = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
        }
        # Exactly one choice, matching what the tests index with [0].
        self.choices = [choice_cls]
        self.usage = type("Usage", (), usage_fields)
# --- Case 1: reply is clean, valid JSON ---------------------------------------
def test_generate_sql_valid_json(monkeypatch):
    """A well-formed JSON reply yields SQL, rationale, token counts and a float cost."""
    provider = OpenAIProvider()
    payload = {"sql": "SELECT * FROM singer;", "rationale": "List all singers."}
    canned = FakeCompletion(json.dumps(payload))

    # Stub client.chat.completions.create so it returns the canned completion
    # for any arguments.
    monkeypatch.setattr(
        provider.client.chat.completions,
        "create",
        lambda *args, **kwargs: canned,
    )

    result = provider.generate_sql(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={},
    )
    sql, rationale, t_in, t_out, cost = result

    assert sql.strip().lower().startswith("select")
    assert "singer" in sql.lower()
    assert "list" in rationale.lower()
    # Token counts come from FakeCompletion's defaults (5 prompt, 7 completion).
    assert t_in == 5
    assert t_out == 7
    assert isinstance(cost, float)
# --- Case 2: JSON object wrapped in extra text (should still recover) --------
def test_generate_sql_recover_from_partial_json(monkeypatch):
    """A JSON object surrounded by extra prose is still parsed by generate_sql."""
    provider = OpenAIProvider()
    # The reply as a whole is not valid JSON: a valid JSON object sits
    # between a leading sentence and a trailing "Thanks!".
    wrapped_reply = 'Here is the result:\n{ "sql": "SELECT * FROM users;", "rationale": "list users" }\nThanks!'
    canned = FakeCompletion(wrapped_reply)

    # Stub the SDK call so it returns the canned completion for any arguments.
    monkeypatch.setattr(
        provider.client.chat.completions,
        "create",
        lambda *args, **kwargs: canned,
    )

    sql, rationale, *_rest = provider.generate_sql(
        user_query="show all users",
        schema_preview="CREATE TABLE users(id int, name text);",
        plan_text="-- plan --",
    )

    lowered_sql = sql.lower()
    assert lowered_sql.startswith("select")
    assert "user" in lowered_sql
    assert "list" in rationale.lower()
# --- Case 3: reply contains no JSON at all (should raise ValueError) ----------
def test_generate_sql_invalid_json(monkeypatch):
    """A reply with no JSON object makes generate_sql raise ValueError."""
    provider = OpenAIProvider()
    canned = FakeCompletion("This is nonsense output without braces")

    # Stub the SDK call so it returns the canned completion for any arguments.
    monkeypatch.setattr(
        provider.client.chat.completions,
        "create",
        lambda *args, **kwargs: canned,
    )

    request = {
        "user_query": "show X",
        "schema_preview": "CREATE TABLE t(id int);",
        "plan_text": "-- plan --",
    }
    with pytest.raises(ValueError):
        provider.generate_sql(**request)
|