BtB-ExpC commited on
Commit
ca454c0
·
1 Parent(s): 38e8d65
chains/diagnoser_chain.py CHANGED
@@ -25,10 +25,8 @@ class DiagnoserChain(BaseModel):
25
  # --- Step 2: Generate a diagnosis using the standardized exercise ---
26
  prompt_diagnose = await self.template_diagnose.aformat_prompt(standardized_exercise=standardized_exercise)
27
  diagnose_messages = prompt_diagnose.to_messages()
28
- diagnosis = ""
29
- async for token in self.llm_diagnose.astream(diagnose_messages):
30
- diagnosis += token
31
- # Here you could, for example, update a UI element if you were streaming tokens to the frontend.
32
  return diagnosis
33
 
34
  class Config:
 
25
  # --- Step 2: Generate a diagnosis using the standardized exercise ---
26
  prompt_diagnose = await self.template_diagnose.aformat_prompt(standardized_exercise=standardized_exercise)
27
  diagnose_messages = prompt_diagnose.to_messages()
28
+ diagnosis = await self.llm_diagnose.ainvoke(diagnose_messages)
29
+
 
 
30
  return diagnosis
31
 
32
  class Config:
chains/distractors_chain.py CHANGED
@@ -6,10 +6,11 @@ from config.exercise_standardizer import standardize_exercise
6
 
7
 
8
  class DistractorsChain(BaseModel):
9
- llm_standardize: Any # Fixed LLM for step 1
10
  template_standardize: ChatPromptTemplate
11
- template: ChatPromptTemplate
12
- llm: Any # User-selectable LLM for step 2
 
 
13
 
14
  async def run(self, user_query: str, exercise_format: str) -> str:
15
  """
@@ -25,20 +26,9 @@ class DistractorsChain(BaseModel):
25
  # --- Step 2: Generate new distractors using the standardized exercise ---
26
  prompt_distractors = await self.template_distractors.aformat_prompt(standardized_exercise=standardized_exercise)
27
  distractors_messages = prompt_distractors.to_messages()
28
- distractors = ""
29
- async for token in self.llm_distr.astream(distractors_messages):
30
- distractors += token
31
- # Here you could, for example, update a UI element if you were streaming tokens to the frontend.
32
- return distractors
33
-
34
-
35
 
36
-
37
-
38
- prompt = await self.template.aformat_prompt(user_input=user_query)
39
- messages = prompt.to_messages()
40
- result = await self.llm.ainvoke(messages)
41
- return result
42
 
43
  class Config:
44
  arbitrary_types_allowed = True
 
6
 
7
 
8
  class DistractorsChain(BaseModel):
 
9
  template_standardize: ChatPromptTemplate
10
+ template_distractors: ChatPromptTemplate
11
+ llm_standardize: Any # Fixed LLM for step 1
12
+ llm_distr: Any # User-selectable LLM for step 2
13
+
14
 
15
  async def run(self, user_query: str, exercise_format: str) -> str:
16
  """
 
26
  # --- Step 2: Generate new distractors using the standardized exercise ---
27
  prompt_distractors = await self.template_distractors.aformat_prompt(standardized_exercise=standardized_exercise)
28
  distractors_messages = prompt_distractors.to_messages()
29
+ distractors = await self.llm_distr.ainvoke(distractors_messages)
 
 
 
 
 
 
30
 
31
+ return distractors
 
 
 
 
 
32
 
33
  class Config:
34
  arbitrary_types_allowed = True
config/exercise_standardizer.py CHANGED
@@ -22,10 +22,6 @@ async def standardize_exercise(user_query: str, exercise_format: str, template:
22
  )
23
 
24
  std_messages = prompt_std.to_messages()
25
-
26
- # Stream tokens to construct the standardized response
27
- standardized_exercise = ""
28
- async for token in llm.astream(std_messages):
29
- standardized_exercise += token
30
 
31
  return standardized_exercise
 
22
  )
23
 
24
  std_messages = prompt_std.to_messages()
25
+ standardized_exercise = await llm.ainvoke(std_messages)
 
 
 
 
26
 
27
  return standardized_exercise
utils/streaming.py DELETED
@@ -1,35 +0,0 @@
1
- # utils/streaming.py
2
- import os
3
- import asyncio
4
- from huggingface_hub import AsyncInferenceClient
5
-
6
-
7
- async def stream_chat_completion(messages, model_name: str, max_tokens: int = 1024):
8
- """
9
- Stream tokens from a Hugging Face Inference endpoint.
10
-
11
- Args:
12
- messages (list[dict]): A list of message dictionaries, e.g.:
13
- [{"role": "system", "content": "You are a helpful assistant."},
14
- {"role": "user", "content": "Count to 10"}]
15
- model_name (str): The identifier for the model (used in the base_url).
16
- max_tokens (int): Maximum tokens to generate.
17
-
18
- Yields:
19
- str: Tokens as they are generated.
20
- """
21
- # Construct a base URL that points to the model’s endpoint.
22
- base_url = f"https://api-inference.huggingface.co/models/{model_name}"
23
- token = os.getenv("HF_API_TOKEN")
24
- client = AsyncInferenceClient(base_url=base_url, token=token)
25
-
26
- stream = await client.chat.completions.create(
27
- messages=messages,
28
- stream=True,
29
- max_tokens=max_tokens,
30
- )
31
-
32
- async for chunk in stream:
33
- # Each chunk is expected to have a structure where the generated text is in:
34
- # chunk.choices[0].delta.content
35
- yield chunk.choices[0].delta.content or ""