recovery

- chains/diagnoser_chain.py +2 -4
- chains/distractors_chain.py +6 -16
- config/exercise_standardizer.py +1 -5
- utils/streaming.py +0 -35
chains/diagnoser_chain.py
CHANGED
@@ -25,10 +25,8 @@ class DiagnoserChain(BaseModel):
         # --- Step 2: Generate a diagnosis using the standardized exercise ---
         prompt_diagnose = await self.template_diagnose.aformat_prompt(standardized_exercise=standardized_exercise)
         diagnose_messages = prompt_diagnose.to_messages()
-        diagnosis = ""
-        async for token in self.llm_diagnose.astream(diagnose_messages):
-            diagnosis += token
-            # Here you could, for example, update a UI element if you were streaming tokens to the frontend.
+        diagnosis = await self.llm_diagnose.astream(diagnose_messages)
+
         return diagnosis

     class Config:
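Review note on the added line: if llm_diagnose is a LangChain chat model (as ChatPromptTemplate and aformat_prompt suggest), astream() returns an async iterator, so awaiting it as a single value raises a TypeError at runtime. The non-streaming call this commit appears to be reaching for is ainvoke(), the same switch made in config/exercise_standardizer.py below. A minimal sketch under that assumption; diagnose_once is a hypothetical helper, not part of this commit:

from langchain_core.language_models import BaseChatModel


async def diagnose_once(llm_diagnose: BaseChatModel, diagnose_messages) -> str:
    # ainvoke() is awaitable and resolves to one complete AIMessage;
    # astream() yields chunks and cannot be awaited as a single result.
    message = await llm_diagnose.ainvoke(diagnose_messages)
    return message.content  # the generated text lives on .content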
chains/distractors_chain.py
CHANGED
@@ -6,10 +6,11 @@ from config.exercise_standardizer import standardize_exercise


 class DistractorsChain(BaseModel):
-    llm_standardize: Any  # Fixed LLM for step 1
     template_standardize: ChatPromptTemplate
-
-
+    template_distr: ChatPromptTemplate
+    llm_standardize: Any  # Fixed LLM for step 1
+    llm_distr: Any  # User-selectable LLM for step 2
+

     async def run(self, user_query: str, exercise_format: str) -> str:
         """
@@ -25,20 +26,9 @@ class DistractorsChain(BaseModel):
         # --- Step 2: Generate new distractors using the standardized exercise ---
         prompt_distractors = await self.template_distractors.aformat_prompt(standardized_exercise=standardized_exercise)
         distractors_messages = prompt_distractors.to_messages()
-        distractors = ""
-        async for token in self.llm_distr.astream(distractors_messages):
-            distractors += token
-            # Here you could, for example, update a UI element if you were streaming tokens to the frontend.
-        return distractors
-
-
+        distractors = await self.llm_distr.astream(distractors_messages)

-
-
-        prompt = await self.template.aformat_prompt(user_input=user_query)
-        messages = prompt.to_messages()
-        result = await self.llm.ainvoke(messages)
-        return result
+        return distractors

     class Config:
         arbitrary_types_allowed = True
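A reviewer might flag one inconsistency in the new hunks: the class now declares template_distr, but run() still formats self.template_distractors, and on a Pydantic BaseModel an undeclared attribute raises an AttributeError. A hedged reconciliation sketch, not the committed code; it assumes standardize_exercise takes the (user_query, exercise_format, template, llm) arguments implied by the truncated signature shown in config/exercise_standardizer.py, and applies the same ainvoke() fix noted for the diagnoser:

from typing import Any

from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate

from config.exercise_standardizer import standardize_exercise


class DistractorsChain(BaseModel):
    template_standardize: ChatPromptTemplate
    template_distr: ChatPromptTemplate
    llm_standardize: Any  # Fixed LLM for step 1
    llm_distr: Any  # User-selectable LLM for step 2

    async def run(self, user_query: str, exercise_format: str) -> str:
        # --- Step 1: Standardize the exercise (assumed call shape) ---
        standardized_exercise = await standardize_exercise(
            user_query, exercise_format, self.template_standardize, self.llm_standardize
        )
        # --- Step 2: Reference the declared field name, template_distr ---
        prompt_distractors = await self.template_distr.aformat_prompt(
            standardized_exercise=standardized_exercise
        )
        distractors_messages = prompt_distractors.to_messages()
        result = await self.llm_distr.ainvoke(distractors_messages)
        return result.content

    class Config:
        arbitrary_types_allowed = True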
config/exercise_standardizer.py
CHANGED
@@ -22,10 +22,6 @@ async def standardize_exercise(user_query: str, exercise_format: str, template:
     )

     std_messages = prompt_std.to_messages()
-
-    # Stream tokens to construct the standardized response
-    standardized_exercise = ""
-    async for token in llm.astream(std_messages):
-        standardized_exercise += token
+    standardized_exercise = await llm.ainvoke(std_messages)

     return standardized_exercise
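Note on the return type: the old token loop produced a plain str, while ainvoke() on a LangChain chat model resolves to an AIMessage, so standardize_exercise now hands callers a message object unless the text is extracted. A two-line sketch under that chat-model assumption (standardize_once is an illustrative name):

from langchain_core.language_models import BaseChatModel


async def standardize_once(llm: BaseChatModel, std_messages) -> str:
    response = await llm.ainvoke(std_messages)
    return response.content  # restores the str the old loop returned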
utils/streaming.py
DELETED
@@ -1,35 +0,0 @@
-# utils/streaming.py
-import os
-import asyncio
-from huggingface_hub import AsyncInferenceClient
-
-
-async def stream_chat_completion(messages, model_name: str, max_tokens: int = 1024):
-    """
-    Stream tokens from a Hugging Face Inference endpoint.
-
-    Args:
-        messages (list[dict]): A list of message dictionaries, e.g.:
-            [{"role": "system", "content": "You are a helpful assistant."},
-             {"role": "user", "content": "Count to 10"}]
-        model_name (str): The identifier for the model (used in the base_url).
-        max_tokens (int): Maximum tokens to generate.
-
-    Yields:
-        str: Tokens as they are generated.
-    """
-    # Construct a base URL that points to the model’s endpoint.
-    base_url = f"https://api-inference.huggingface.co/models/{model_name}"
-    token = os.getenv("HF_API_TOKEN")
-    client = AsyncInferenceClient(base_url=base_url, token=token)
-
-    stream = await client.chat.completions.create(
-        messages=messages,
-        stream=True,
-        max_tokens=max_tokens,
-    )
-
-    async for chunk in stream:
-        # Each chunk is expected to have a structure where the generated text is in:
-        # chunk.choices[0].delta.content
-        yield chunk.choices[0].delta.content or ""
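With the streaming helper deleted, any remaining call sites need a one-shot path. The sketch below mirrors the deleted file's own client setup but awaits the full response instead of iterating chunks; it is a hypothetical replacement, not part of this commit, and fetch_chat_completion is a made-up name:

# hypothetical non-streaming counterpart to the deleted helper
import os

from huggingface_hub import AsyncInferenceClient


async def fetch_chat_completion(messages, model_name: str, max_tokens: int = 1024) -> str:
    """Return the complete assistant reply from a Hugging Face Inference endpoint."""
    base_url = f"https://api-inference.huggingface.co/models/{model_name}"
    client = AsyncInferenceClient(base_url=base_url, token=os.getenv("HF_API_TOKEN"))

    # Without stream=True, create() resolves to a single completion object.
    completion = await client.chat.completions.create(
        messages=messages,
        max_tokens=max_tokens,
    )
    return completion.choices[0].message.content  # full generated text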