import asyncio
from typing import AsyncGenerator

from config.llm_config import llms
from config.chain_configs import chain_configs
from app.helpers.study_text_standardizer import standardize_studytext
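
# Sketch of the assumed shape of the "learning_objectives" entry in
# chain_configs: the keys below are exactly those read in this module,
# while the concrete values (prompt templates and LLM handles) live in
# chain_configs.py.
#
#   chain_configs["learning_objectives"] = {
#       "template_standardize":  ...,  # prompt used by standardize_studytext
#       "llm_standardize":       ...,  # LLM used for standardization
#       "template_gen_prompt_a": ...,  # generation prompt A
#       "template_gen_prompt_b": ...,  # generation prompt B
#       "template_sanitize":     ...,  # sanitization prompt
#       "default_llm_a":         ...,  # fallback LLM for user slot 1
#       "default_llm_b":         ...,  # fallback LLM for user slot 2
#       "llm_sanitize":          ...,  # key into `llms` for the sanitizer
#   }
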
async def run_learning_objectives_generator(
    user_input_text: str,
    model_choice_1: str,
    model_choice_2: str,
    text_format: str,
) -> AsyncGenerator[tuple, None]:
"""
Orchestrates the entire pipeline:
1) Standardize the study text
2) Generate (2 prompts × 2 LLMs) => 4 partial results
3) Sanitize each partial result
4) Yield partial updates in real time as each track completes
"""
    # 1) Standardize the text once
    config = chain_configs["learning_objectives"]  # defined in chain_configs.py
    standardized = await standardize_studytext(
        user_input_text,
        text_format,
        config["template_standardize"],
        config["llm_standardize"],
    )

    # Prepare references for the generation prompts:
    prompt_a = config["template_gen_prompt_a"]
    prompt_b = config["template_gen_prompt_b"]
    sanitize_prompt = config["template_sanitize"]

    # Pick the LLMs from the user's choices, falling back to the configured
    # defaults when a choice is not a known key.
    llm_a = llms.get(model_choice_1, config["default_llm_a"])
    llm_b = llms.get(model_choice_2, config["default_llm_b"])
    # Index directly so a missing sanitizer key raises a clear KeyError
    # instead of silently producing None (as a bare .get() would).
    llm_sanitize = llms[config["llm_sanitize"]]

    # We will store the final sanitized results in an array of 4 strings
    # (2 prompts × 2 LLMs).
    partial_results = ["", "", "", ""]

    # A short async helper for each track:
    # - 'track_index' is 0..3, identifying which of the 4 textboxes to fill
    # - 'gen_prompt' is either prompt_a or prompt_b
    # - 'gen_llm' is either llm_a or llm_b
    async def run_track(track_index: int, gen_prompt, gen_llm):
        # Step: generate
        gen_msg = await gen_prompt.aformat_prompt(standardized_text=standardized)
        gen_resp = await gen_llm.ainvoke(gen_msg.to_messages())
        generation_output = getattr(gen_resp, "content", gen_resp)

        # Step: sanitize with the dedicated sanitization LLM
        sanitize_msg = await sanitize_prompt.aformat_prompt(raw_output=generation_output)
        sanitize_resp = await llm_sanitize.ainvoke(sanitize_msg.to_messages())
        sanitized_output = getattr(sanitize_resp, "content", sanitize_resp)

        return (track_index, sanitized_output)

    # Build the 4 tasks:
    # - track 0 => prompt A, LLM 1
    # - track 1 => prompt B, LLM 1
    # - track 2 => prompt A, LLM 2
    # - track 3 => prompt B, LLM 2
    tasks = [
        run_track(0, prompt_a, llm_a),
        run_track(1, prompt_b, llm_a),
        run_track(2, prompt_a, llm_b),
        run_track(3, prompt_b, llm_b),
    ]

    # Run the tracks concurrently and yield an update as each one finishes.
    for coro in asyncio.as_completed(tasks):
        track_index, final_text = await coro
        partial_results[track_index] = final_text
        # Yield a partial update: the 4 track results plus the standardized
        # text. The UI maps each item to the correct textbox.
        yield tuple(partial_results) + (standardized,)
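
# A minimal usage sketch, assuming the module is run directly. The model
# names below are hypothetical placeholders -- unknown keys simply fall back
# to the configured default LLMs via llms.get(..., default) above.
async def _demo() -> None:
    async for update in run_learning_objectives_generator(
        user_input_text="Photosynthesis converts light energy into chemical energy.",
        model_choice_1="model-a",  # hypothetical key into `llms`
        model_choice_2="model-b",  # hypothetical key into `llms`
        text_format="plain",
    ):
        *tracks, standardized_text = update
        # Preview the first 40 characters of each track as results arrive.
        print([t[:40] for t in tracks])


if __name__ == "__main__":
    asyncio.run(_demo())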