jashdoshi77 committed on
Commit
37a0153
·
1 Parent(s): 1b2fa72

fixed dspy errors

Browse files
Files changed (2) hide show
  1. ai/groq_setup.py +21 -28
  2. ai/pipeline.py +6 -5
ai/groq_setup.py CHANGED
@@ -1,40 +1,33 @@
1
- """DSPy language model setup for Groq and OpenAI.
2
 
3
- Provides a factory function to create the right LM based on the
4
- user-selected provider.
5
  """
6
 
7
  import dspy
8
  import config
9
 
10
 
11
- # Configure DSPy ONCE at import time, on the main thread, with Groq as default.
12
- _default_lm = dspy.LM(
13
- model=f"groq/{config.GROQ_MODEL}",
14
- api_key=config.GROQ_API_KEY,
15
- max_tokens=4096,
16
- temperature=0.2,
17
- )
18
- dspy.configure(lm=_default_lm)
 
 
19
 
20
 
21
- def get_lm(provider: str = "groq") -> dspy.LM:
22
- """Return a DSPy language-model instance for the requested provider.
23
 
24
- This does NOT call dspy.configure to avoid thread-safety issues;
25
- the global settings are configured once at import using Groq.
26
 
27
- Parameters
28
- ----------
29
- provider : "groq" | "openai"
30
  """
31
- if provider == "openai":
32
- return dspy.LM(
33
- model=f"openai/{config.OPENAI_MODEL}",
34
- api_key=config.OPENAI_API_KEY,
35
- max_tokens=4096,
36
- temperature=0.2,
37
- )
38
-
39
- # Default / Groq: reuse the global LM so we share configuration.
40
- return _default_lm
 
1
+ """DSPy language model setup for Groq (and optionally OpenAI).
2
 
3
+ We configure DSPy ONCE at import time on the main thread to avoid
4
+ thread-safety issues with FastAPI's worker threads.
5
  """
6
 
7
  import dspy
8
  import config
9
 
10
 
11
+ def _configure_default_lm() -> dspy.LM:
12
+ """Configure the global DSPy LM once and return it."""
13
+ lm = dspy.LM(
14
+ model=f"groq/{config.GROQ_MODEL}",
15
+ api_key=config.GROQ_API_KEY,
16
+ max_tokens=4096,
17
+ temperature=0.2,
18
+ )
19
+ dspy.configure(lm=lm)
20
+ return lm
21
 
22
 
23
+ _DEFAULT_LM = _configure_default_lm()
24
+
25
 
26
+ def get_lm(provider: str = "groq") -> dspy.LM:
27
+ """Return the LM instance to use.
28
 
29
+ NOTE: To keep things simple and robust inside the web server, we always
30
+ use the globally configured LM. The `provider` argument is accepted for
31
+ future extension but currently ignored.
32
  """
33
+ return _DEFAULT_LM
 
 
 
 
 
 
 
 
 
ai/pipeline.py CHANGED
@@ -36,14 +36,15 @@ class SQLAnalystPipeline:
36
  """End-to-end reasoning pipeline: question → SQL → results → insights."""
37
 
38
  def __init__(self, provider: str = "groq"):
 
39
  self.provider = provider
40
  self._lm = get_lm(provider)
41
 
42
- # DSPy predict modules — each bound to the chosen LM instance
43
- self.analyze = dspy.Predict(AnalyzeAndPlan, lm=self._lm)
44
- self.generate_sql = dspy.Predict(SQLGeneration, lm=self._lm)
45
- self.interpret = dspy.Predict(InterpretAndInsight, lm=self._lm)
46
- self.repair = dspy.Predict(SQLRepair, lm=self._lm)
47
 
48
  # ── public API ──────────────────────────────────────────────────────
49
 
 
36
  """End-to-end reasoning pipeline: question → SQL → results → insights."""
37
 
38
  def __init__(self, provider: str = "groq"):
39
+ # For now we always use the globally configured LM (Groq).
40
  self.provider = provider
41
  self._lm = get_lm(provider)
42
 
43
+ # DSPy predict modules β€” rely on global dspy.settings
44
+ self.analyze = dspy.Predict(AnalyzeAndPlan)
45
+ self.generate_sql = dspy.Predict(SQLGeneration)
46
+ self.interpret = dspy.Predict(InterpretAndInsight)
47
+ self.repair = dspy.Predict(SQLRepair)
48
 
49
  # ── public API ──────────────────────────────────────────────────────
50