Spaces:
Running
Running
Commit Β·
1b2fa72
1
Parent(s): 559689a
added logging and fixed dspy error
Browse files
- ai/groq_setup.py +17 -11
- ai/pipeline.py +5 -5
- app.py +24 -0
ai/groq_setup.py
CHANGED
|
@@ -8,27 +8,33 @@ import dspy
|
|
| 8 |
import config
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def get_lm(provider: str = "groq") -> dspy.LM:
|
| 12 |
-
"""Return a
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
Parameters
|
| 15 |
----------
|
| 16 |
provider : "groq" | "openai"
|
| 17 |
"""
|
| 18 |
if provider == "openai":
|
| 19 |
-
|
| 20 |
model=f"openai/{config.OPENAI_MODEL}",
|
| 21 |
api_key=config.OPENAI_API_KEY,
|
| 22 |
max_tokens=4096,
|
| 23 |
temperature=0.2,
|
| 24 |
)
|
| 25 |
-
else: # default: groq
|
| 26 |
-
lm = dspy.LM(
|
| 27 |
-
model=f"groq/{config.GROQ_MODEL}",
|
| 28 |
-
api_key=config.GROQ_API_KEY,
|
| 29 |
-
max_tokens=4096,
|
| 30 |
-
temperature=0.2,
|
| 31 |
-
)
|
| 32 |
|
| 33 |
-
|
| 34 |
-
return
|
|
|
|
| 8 |
import config
|
| 9 |
|
| 10 |
|
| 11 |
+
# Configure DSPy ONCE at import time, on the main thread, with Groq as default.
|
| 12 |
+
_default_lm = dspy.LM(
|
| 13 |
+
model=f"groq/{config.GROQ_MODEL}",
|
| 14 |
+
api_key=config.GROQ_API_KEY,
|
| 15 |
+
max_tokens=4096,
|
| 16 |
+
temperature=0.2,
|
| 17 |
+
)
|
| 18 |
+
dspy.configure(lm=_default_lm)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
def get_lm(provider: str = "groq") -> dspy.LM:
|
| 22 |
+
"""Return a DSPy language-model instance for the requested provider.
|
| 23 |
+
|
| 24 |
+
This does NOT call dspy.configure to avoid thread-safety issues;
|
| 25 |
+
the global settings are configured once at import using Groq.
|
| 26 |
|
| 27 |
Parameters
|
| 28 |
----------
|
| 29 |
provider : "groq" | "openai"
|
| 30 |
"""
|
| 31 |
if provider == "openai":
|
| 32 |
+
return dspy.LM(
|
| 33 |
model=f"openai/{config.OPENAI_MODEL}",
|
| 34 |
api_key=config.OPENAI_API_KEY,
|
| 35 |
max_tokens=4096,
|
| 36 |
temperature=0.2,
|
| 37 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
# Default / Groq: reuse the global LM so we share configuration.
|
| 40 |
+
return _default_lm
|
ai/pipeline.py
CHANGED
|
@@ -39,11 +39,11 @@ class SQLAnalystPipeline:
|
|
| 39 |
self.provider = provider
|
| 40 |
self._lm = get_lm(provider)
|
| 41 |
|
| 42 |
-
# DSPy predict modules
|
| 43 |
-
self.analyze = dspy.Predict(AnalyzeAndPlan)
|
| 44 |
-
self.generate_sql = dspy.Predict(SQLGeneration)
|
| 45 |
-
self.interpret = dspy.Predict(InterpretAndInsight)
|
| 46 |
-
self.repair = dspy.Predict(SQLRepair)
|
| 47 |
|
| 48 |
# ββ public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
|
|
|
|
| 39 |
self.provider = provider
|
| 40 |
self._lm = get_lm(provider)
|
| 41 |
|
| 42 |
+
# DSPy predict modules β each bound to the chosen LM instance
|
| 43 |
+
self.analyze = dspy.Predict(AnalyzeAndPlan, lm=self._lm)
|
| 44 |
+
self.generate_sql = dspy.Predict(SQLGeneration, lm=self._lm)
|
| 45 |
+
self.interpret = dspy.Predict(InterpretAndInsight, lm=self._lm)
|
| 46 |
+
self.repair = dspy.Predict(SQLRepair, lm=self._lm)
|
| 47 |
|
| 48 |
# ββ public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
|
app.py
CHANGED
|
@@ -10,6 +10,7 @@ from fastapi.staticfiles import StaticFiles
|
|
| 10 |
from pydantic import BaseModel
|
| 11 |
|
| 12 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
|
|
|
| 13 |
|
| 14 |
app = FastAPI(title="AI SQL Analyst", version="1.0.0")
|
| 15 |
|
|
@@ -57,12 +58,24 @@ def chat_endpoint(req: QuestionRequest):
|
|
| 57 |
from ai.pipeline import SQLAnalystPipeline
|
| 58 |
from db.memory import get_recent_history, add_turn
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
conversation_id = req.conversation_id or "default"
|
| 61 |
|
| 62 |
history = get_recent_history(conversation_id, limit=5)
|
| 63 |
|
| 64 |
# Augment the question with recent conversation context
|
| 65 |
if history:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
history_lines: list[str] = ["You are in a multi-turn conversation. Here are the recent exchanges:"]
|
| 67 |
for turn in history:
|
| 68 |
history_lines.append(f"User: {turn['question']}")
|
|
@@ -70,11 +83,22 @@ def chat_endpoint(req: QuestionRequest):
|
|
| 70 |
history_lines.append(f"Now the user asks: {req.question}")
|
| 71 |
question_with_context = "\n".join(history_lines)
|
| 72 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
question_with_context = req.question
|
| 74 |
|
| 75 |
pipeline = SQLAnalystPipeline(provider=req.provider)
|
| 76 |
result = pipeline.run(question_with_context)
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
# Persist this turn for future context
|
| 79 |
add_turn(conversation_id, req.question, result["answer"], result["sql"])
|
| 80 |
|
|
|
|
| 10 |
from pydantic import BaseModel
|
| 11 |
|
| 12 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
|
| 13 |
+
logger = logging.getLogger("api")
|
| 14 |
|
| 15 |
app = FastAPI(title="AI SQL Analyst", version="1.0.0")
|
| 16 |
|
|
|
|
| 58 |
from ai.pipeline import SQLAnalystPipeline
|
| 59 |
from db.memory import get_recent_history, add_turn
|
| 60 |
|
| 61 |
+
logger.info(
|
| 62 |
+
"CHAT request | provider=%s | conversation_id=%s | question=%s",
|
| 63 |
+
req.provider,
|
| 64 |
+
req.conversation_id or "default",
|
| 65 |
+
req.question,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
conversation_id = req.conversation_id or "default"
|
| 69 |
|
| 70 |
history = get_recent_history(conversation_id, limit=5)
|
| 71 |
|
| 72 |
# Augment the question with recent conversation context
|
| 73 |
if history:
|
| 74 |
+
logger.info(
|
| 75 |
+
"CHAT context | conversation_id=%s | history_turns=%d",
|
| 76 |
+
conversation_id,
|
| 77 |
+
len(history),
|
| 78 |
+
)
|
| 79 |
history_lines: list[str] = ["You are in a multi-turn conversation. Here are the recent exchanges:"]
|
| 80 |
for turn in history:
|
| 81 |
history_lines.append(f"User: {turn['question']}")
|
|
|
|
| 83 |
history_lines.append(f"Now the user asks: {req.question}")
|
| 84 |
question_with_context = "\n".join(history_lines)
|
| 85 |
else:
|
| 86 |
+
logger.info(
|
| 87 |
+
"CHAT context | conversation_id=%s | history_turns=0 (no prior context used)",
|
| 88 |
+
conversation_id,
|
| 89 |
+
)
|
| 90 |
question_with_context = req.question
|
| 91 |
|
| 92 |
pipeline = SQLAnalystPipeline(provider=req.provider)
|
| 93 |
result = pipeline.run(question_with_context)
|
| 94 |
|
| 95 |
+
logger.info(
|
| 96 |
+
"CHAT result | conversation_id=%s | used_context=%s | sql_preview=%s",
|
| 97 |
+
conversation_id,
|
| 98 |
+
"yes" if history else "no",
|
| 99 |
+
(result.get("sql") or "").replace("\n", " ")[:200],
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
# Persist this turn for future context
|
| 103 |
add_turn(conversation_id, req.question, result["answer"], result["sql"])
|
| 104 |
|