jashdoshi77 committed on
Commit
1b2fa72
·
1 Parent(s): 559689a

added logging and fixed DSPy error

Browse files
Files changed (3) hide show
  1. ai/groq_setup.py +17 -11
  2. ai/pipeline.py +5 -5
  3. app.py +24 -0
ai/groq_setup.py CHANGED
@@ -8,27 +8,33 @@ import dspy
8
  import config
9
 
10
 
 
 
 
 
 
 
 
 
 
 
11
def get_lm(provider: str = "groq") -> dspy.LM:
    """Return a configured DSPy language-model instance.

    Builds an LM for the requested provider, installs it as the global
    DSPy default via ``dspy.configure``, and returns it.

    Parameters
    ----------
    provider : "groq" | "openai"
        Any value other than "openai" falls back to Groq.

    Returns
    -------
    dspy.LM
        The newly constructed language model.

    Notes
    -----
    ``dspy.configure`` mutates process-global DSPy settings, so calling
    this concurrently from worker threads with different providers is
    race-prone — prefer calling it once at startup.
    """
    # Both branches share the same sampling settings; only the model
    # string and API key differ, so select those per-provider and build
    # the LM once instead of duplicating the constructor call.
    if provider == "openai":
        model = f"openai/{config.OPENAI_MODEL}"
        api_key = config.OPENAI_API_KEY
    else:  # default: groq
        model = f"groq/{config.GROQ_MODEL}"
        api_key = config.GROQ_API_KEY

    lm = dspy.LM(
        model=model,
        api_key=api_key,
        max_tokens=4096,
        temperature=0.2,
    )

    # Preserve the original side effect: make this LM the global default.
    dspy.configure(lm=lm)
    return lm
 
8
  import config
9
 
10
 
11
# Configure DSPy exactly once at import time, on the main thread, with
# Groq as the default provider.  dspy.configure mutates process-global
# settings, so doing it here (rather than per-request) sidesteps the
# thread-safety issues of reconfiguring from request handlers.
_default_lm = dspy.LM(
    model="groq/" + config.GROQ_MODEL,
    api_key=config.GROQ_API_KEY,
    max_tokens=4096,
    temperature=0.2,
)
dspy.configure(lm=_default_lm)
19
+
20
+
21
def get_lm(provider: str = "groq") -> dspy.LM:
    """Return a DSPy language-model instance for *provider*.

    Deliberately does NOT call ``dspy.configure``: the global DSPy
    settings are installed once at import time with the Groq default,
    which avoids mutating shared state from worker threads.

    Parameters
    ----------
    provider : "groq" | "openai"
        Anything other than "openai" (including the default) yields the
        shared module-level Groq LM.
    """
    # Guard clause: every non-OpenAI provider reuses the module-level LM
    # so configuration stays shared.
    if provider != "openai":
        return _default_lm

    # OpenAI callers get a fresh LM instance; global settings untouched.
    return dspy.LM(
        model=f"openai/{config.OPENAI_MODEL}",
        api_key=config.OPENAI_API_KEY,
        max_tokens=4096,
        temperature=0.2,
    )
ai/pipeline.py CHANGED
@@ -39,11 +39,11 @@ class SQLAnalystPipeline:
39
  self.provider = provider
40
  self._lm = get_lm(provider)
41
 
42
- # DSPy predict modules
43
- self.analyze = dspy.Predict(AnalyzeAndPlan)
44
- self.generate_sql = dspy.Predict(SQLGeneration)
45
- self.interpret = dspy.Predict(InterpretAndInsight)
46
- self.repair = dspy.Predict(SQLRepair)
47
 
48
  # ── public API ──────────────────────────────────────────────────────
49
 
 
39
  self.provider = provider
40
  self._lm = get_lm(provider)
41
 
42
+ # DSPy predict modules β€” each bound to the chosen LM instance
43
+ self.analyze = dspy.Predict(AnalyzeAndPlan, lm=self._lm)
44
+ self.generate_sql = dspy.Predict(SQLGeneration, lm=self._lm)
45
+ self.interpret = dspy.Predict(InterpretAndInsight, lm=self._lm)
46
+ self.repair = dspy.Predict(SQLRepair, lm=self._lm)
47
 
48
  # ── public API ──────────────────────────────────────────────────────
49
 
app.py CHANGED
@@ -10,6 +10,7 @@ from fastapi.staticfiles import StaticFiles
10
  from pydantic import BaseModel
11
 
12
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
 
13
 
14
  app = FastAPI(title="AI SQL Analyst", version="1.0.0")
15
 
@@ -57,12 +58,24 @@ def chat_endpoint(req: QuestionRequest):
57
  from ai.pipeline import SQLAnalystPipeline
58
  from db.memory import get_recent_history, add_turn
59
 
 
 
 
 
 
 
 
60
  conversation_id = req.conversation_id or "default"
61
 
62
  history = get_recent_history(conversation_id, limit=5)
63
 
64
  # Augment the question with recent conversation context
65
  if history:
 
 
 
 
 
66
  history_lines: list[str] = ["You are in a multi-turn conversation. Here are the recent exchanges:"]
67
  for turn in history:
68
  history_lines.append(f"User: {turn['question']}")
@@ -70,11 +83,22 @@ def chat_endpoint(req: QuestionRequest):
70
  history_lines.append(f"Now the user asks: {req.question}")
71
  question_with_context = "\n".join(history_lines)
72
  else:
 
 
 
 
73
  question_with_context = req.question
74
 
75
  pipeline = SQLAnalystPipeline(provider=req.provider)
76
  result = pipeline.run(question_with_context)
77
 
 
 
 
 
 
 
 
78
  # Persist this turn for future context
79
  add_turn(conversation_id, req.question, result["answer"], result["sql"])
80
 
 
10
  from pydantic import BaseModel
11
 
12
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
13
+ logger = logging.getLogger("api")
14
 
15
  app = FastAPI(title="AI SQL Analyst", version="1.0.0")
16
 
 
58
  from ai.pipeline import SQLAnalystPipeline
59
  from db.memory import get_recent_history, add_turn
60
 
61
+ logger.info(
62
+ "CHAT request | provider=%s | conversation_id=%s | question=%s",
63
+ req.provider,
64
+ req.conversation_id or "default",
65
+ req.question,
66
+ )
67
+
68
  conversation_id = req.conversation_id or "default"
69
 
70
  history = get_recent_history(conversation_id, limit=5)
71
 
72
  # Augment the question with recent conversation context
73
  if history:
74
+ logger.info(
75
+ "CHAT context | conversation_id=%s | history_turns=%d",
76
+ conversation_id,
77
+ len(history),
78
+ )
79
  history_lines: list[str] = ["You are in a multi-turn conversation. Here are the recent exchanges:"]
80
  for turn in history:
81
  history_lines.append(f"User: {turn['question']}")
 
83
  history_lines.append(f"Now the user asks: {req.question}")
84
  question_with_context = "\n".join(history_lines)
85
  else:
86
+ logger.info(
87
+ "CHAT context | conversation_id=%s | history_turns=0 (no prior context used)",
88
+ conversation_id,
89
+ )
90
  question_with_context = req.question
91
 
92
  pipeline = SQLAnalystPipeline(provider=req.provider)
93
  result = pipeline.run(question_with_context)
94
 
95
+ logger.info(
96
+ "CHAT result | conversation_id=%s | used_context=%s | sql_preview=%s",
97
+ conversation_id,
98
+ "yes" if history else "no",
99
+ (result.get("sql") or "").replace("\n", " ")[:200],
100
+ )
101
+
102
  # Persist this turn for future context
103
  add_turn(conversation_id, req.question, result["answer"], result["sql"])
104