Spaces:
Sleeping
Sleeping
| import json | |
| import pytest | |
| from adapters.llm.openai_provider import OpenAIProvider | |
# Helper class that fakes the chat-completion object returned by the OpenAI SDK
class FakeCompletion:
    """Minimal stand-in for a chat-completion response object.

    Only the attributes these tests read are provided:
    ``choices[0].message.content``, ``usage.prompt_tokens`` and
    ``usage.completion_tokens``.

    Args:
        content: Text placed in ``choices[0].message.content``.
        prompt_tokens: Value exposed as ``usage.prompt_tokens``.
        completion_tokens: Value exposed as ``usage.completion_tokens``.
    """

    def __init__(self, content: str, prompt_tokens: int = 5, completion_tokens: int = 7):
        # Local import keeps this helper self-contained. SimpleNamespace yields
        # real attribute-bearing instances, unlike the original approach of
        # storing ad-hoc classes built with type() in place of objects.
        from types import SimpleNamespace

        message = SimpleNamespace(content=content)
        self.choices = [SimpleNamespace(message=message)]
        self.usage = SimpleNamespace(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
| # --- Case 1: clean valid JSON -------------------------------------------------- | |
def test_generate_sql_valid_json(monkeypatch):
    """A clean JSON reply yields SQL, rationale, token counts and a float cost."""
    provider = OpenAIProvider()

    payload = {"sql": "SELECT * FROM singer;", "rationale": "List all singers."}
    canned = FakeCompletion(json.dumps(payload))

    # Swap out the SDK call so no network request is made.
    monkeypatch.setattr(
        provider.client.chat.completions,
        "create",
        lambda *args, **kwargs: canned,
    )

    result = provider.generate_sql(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={},
    )
    sql, rationale, tokens_in, tokens_out, cost = result

    assert sql.strip().lower().startswith("select")
    assert "singer" in sql.lower()
    assert "list" in rationale.lower()
    # Token counts come straight from FakeCompletion's defaults.
    assert tokens_in == 5
    assert tokens_out == 7
    assert isinstance(cost, float)
# --- Case 2: valid JSON wrapped in extra prose (should still recover) ---------
def test_generate_sql_recover_from_partial_json(monkeypatch):
    """A JSON object embedded in chatty prose is still extracted and parsed."""
    provider = OpenAIProvider()

    # The JSON object itself is well-formed; the leading and trailing prose
    # is what the provider has to see past.
    chatty_reply = (
        "Here is the result:\n"
        '{ "sql": "SELECT * FROM users;", "rationale": "list users" }\n'
        "Thanks!"
    )
    canned = FakeCompletion(chatty_reply)
    monkeypatch.setattr(
        provider.client.chat.completions,
        "create",
        lambda *args, **kwargs: canned,
    )

    sql, rationale, *_unused = provider.generate_sql(
        user_query="show all users",
        schema_preview="CREATE TABLE users(id int, name text);",
        plan_text="-- plan --",
    )

    lowered_sql = sql.lower()
    assert lowered_sql.startswith("select")
    assert "user" in lowered_sql
    assert "list" in rationale.lower()
| # --- Case 3: completely invalid JSON (should raise ValueError) ---------------- | |
def test_generate_sql_invalid_json(monkeypatch):
    """Model output with no JSON object at all must surface as a ValueError."""
    provider = OpenAIProvider()
    canned = FakeCompletion("This is nonsense output without braces")
    monkeypatch.setattr(
        provider.client.chat.completions,
        "create",
        lambda *args, **kwargs: canned,
    )

    request = dict(
        user_query="show X",
        schema_preview="CREATE TABLE t(id int);",
        plan_text="-- plan --",
    )
    with pytest.raises(ValueError):
        provider.generate_sql(**request)