| | import asyncio |
| | import json |
| | from fastapi import APIRouter, Depends |
| | from httpx import AsyncClient |
| | from jinja2 import Environment, TemplateNotFound |
| | from litellm.router import Router |
| | from dependencies import INSIGHT_FINDER_BASE_URL, get_http_client, get_llm_router, get_prompt_templates |
| | from typing import Awaitable, Callable, TypeVar |
| | from schemas import _RefinedSolutionModel, _BootstrappedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, InsightFinderConstraintsList, PriorArtSearchRequest, PriorArtSearchResponse, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse, SolutionCriticism, SolutionModel, SolutionBootstrapResponse, SolutionBootstrapRequest, TechnologyData |
| |
|
| | |
# Shared FastAPI router for every solution-generation endpoint in this module;
# included by the application with these tags for the OpenAPI docs.
router = APIRouter(tags=["solution generation and critique"])

# Generic type parameters used by `retry_until` below:
#   T — the awaited result type of the retried coroutine.
#   A — the argument type passed through to the retried coroutine.
T = TypeVar("T")
A = TypeVar("A")
| |
|
| |
|
async def retry_until(
    func: Callable[[A], Awaitable[T]],
    arg: A,
    predicate: Callable[[T], bool],
    max_retries: int,
) -> T:
    """Call ``func(arg)`` repeatedly until ``predicate`` accepts the result.

    Performs one initial call plus at most ``max_retries`` retries. The
    predicate is never invoked on the final retry's result; if it never
    accepts a value, the most recent result is returned as-is (no exception
    is raised).
    """
    attempt = 0
    result = await func(arg)
    while attempt < max_retries and not predicate(result):
        result = await func(arg)
        attempt += 1
    return result
| |
|
| | |
| |
|
| |
|
| | @router.post("/bootstrap_solutions") |
| | async def bootstrap_solutions(req: SolutionBootstrapRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router), http_client: AsyncClient = Depends(get_http_client)) -> SolutionBootstrapResponse: |
| | """ |
| | Boostraps a solution for each of the passed in requirements categories using Insight Finder's API. |
| | """ |
| |
|
| | async def _bootstrap_solution_inner(cat: ReqGroupingCategory): |
| | |
| | fmt_completion = await llm_router.acompletion("gemini-v2", messages=[ |
| | { |
| | "role": "user", |
| | "content": await prompt_env.get_template("format_requirements.txt").render_async(**{ |
| | "category": cat.model_dump(), |
| | "response_schema": InsightFinderConstraintsList.model_json_schema() |
| | }) |
| | }], response_format=InsightFinderConstraintsList) |
| |
|
| | fmt_model = InsightFinderConstraintsList.model_validate_json( |
| | fmt_completion.choices[0].message.content) |
| |
|
| | |
| | formatted_constraints = {'constraints': { |
| | cons.title: cons.description for cons in fmt_model.constraints}} |
| |
|
| | |
| | technologies_req = await http_client.post(INSIGHT_FINDER_BASE_URL + "process-constraints", content=json.dumps(formatted_constraints)) |
| | technologies = TechnologyData.model_validate(technologies_req.json()) |
| |
|
| | |
| |
|
| | format_solution = await llm_router.acompletion("gemini-v2", messages=[{ |
| | "role": "user", |
| | "content": await prompt_env.get_template("bootstrap_solution.txt").render_async(**{ |
| | "category": cat.model_dump(), |
| | "technologies": technologies.model_dump()["technologies"], |
| | "user_constraints": req.user_constraints, |
| | "response_schema": _BootstrappedSolutionModel.model_json_schema() |
| | })} |
| | ], response_format=_BootstrappedSolutionModel) |
| |
|
| | format_solution_model = _BootstrappedSolutionModel.model_validate_json( |
| | format_solution.choices[0].message.content) |
| |
|
| | final_solution = SolutionModel( |
| | context="", |
| | requirements=[ |
| | cat.requirements[i].requirement for i in format_solution_model.requirement_ids |
| | ], |
| | problem_description=format_solution_model.problem_description, |
| | solution_description=format_solution_model.solution_description, |
| | references=[], |
| | category_id=cat.id, |
| | ) |
| |
|
| | |
| |
|
| | return final_solution |
| |
|
| | tasks = await asyncio.gather(*[_bootstrap_solution_inner(cat) for cat in req.categories], return_exceptions=True) |
| | final_solutions = [sol for sol in tasks if not isinstance(sol, Exception)] |
| |
|
| | return SolutionBootstrapResponse(solutions=final_solutions) |
| |
|
| |
|
| | @router.post("/criticize_solution", response_model=CritiqueResponse) |
| | async def criticize_solution(params: CriticizeSolutionsRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> CritiqueResponse: |
| | """Criticize the challenges, weaknesses and limitations of the provided solutions.""" |
| |
|
| | async def __criticize_single(solution: SolutionModel): |
| | req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{ |
| | "solutions": [solution.model_dump()], |
| | "response_schema": _SolutionCriticismOutput.model_json_schema() |
| | }) |
| |
|
| | req_completion = await llm_router.acompletion( |
| | model="gemini-v2", |
| | messages=[{"role": "user", "content": req_prompt}], |
| | response_format=_SolutionCriticismOutput |
| | ) |
| |
|
| | criticism_out = _SolutionCriticismOutput.model_validate_json( |
| | req_completion.choices[0].message.content |
| | ) |
| |
|
| | return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0]) |
| |
|
| | critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False) |
| | return CritiqueResponse(critiques=critiques) |
| |
|
| |
|
| | |
| |
|
| | @router.post("/refine_solutions", response_model=SolutionBootstrapResponse) |
| | async def refine_solutions(params: CritiqueResponse, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> SolutionBootstrapResponse: |
| | """Refines the previously critiqued solutions.""" |
| |
|
| | async def __refine_solution(crit: SolutionCriticism): |
| | req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{ |
| | "solution": crit.solution.model_dump(), |
| | "criticism": crit.criticism, |
| | "response_schema": _RefinedSolutionModel.model_json_schema(), |
| | }) |
| |
|
| | req_completion = await llm_router.acompletion(model="gemini-v2", messages=[ |
| | {"role": "user", "content": req_prompt} |
| | ], response_format=_RefinedSolutionModel) |
| |
|
| | req_model = _RefinedSolutionModel.model_validate_json( |
| | req_completion.choices[0].message.content) |
| |
|
| | |
| | refined_solution = crit.solution.model_copy(deep=True) |
| | refined_solution.problem_description = req_model.problem_description |
| | refined_solution.solution_description = req_model.solution_description |
| |
|
| | return refined_solution |
| |
|
| | refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques], return_exceptions=False) |
| |
|
| | return SolutionBootstrapResponse(solutions=refined_solutions) |
| |
|
| |
|
| | @router.post("/search_prior_art") |
| | async def search_prior_art(req: PriorArtSearchRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> PriorArtSearchResponse: |
| | """Performs a comprehensive prior art search / FTO search against the provided topics for a drafted solution""" |
| |
|
| | sema = asyncio.Semaphore(4) |
| |
|
| | async def __search_topic(topic: str) -> str: |
| | search_prompt = await prompt_env.get_template("search/search_topic.txt").render_async(**{ |
| | "topic": topic |
| | }) |
| |
|
| | try: |
| | await sema.acquire() |
| |
|
| | search_completion = await llm_router.acompletion(model="gemini-v2", messages=[ |
| | {"role": "user", "content": search_prompt} |
| | ], temperature=0.3, tools=[{"googleSearch": {}}]) |
| |
|
| | return {"topic": topic, "content": search_completion.choices[0].message.content} |
| | finally: |
| | sema.release() |
| |
|
| | |
| | topics = await asyncio.gather(*[__search_topic(top) for top in req.topics], return_exceptions=False) |
| |
|
| | consolidation_prompt = await prompt_env.get_template("search/build_final_report.txt").render_async(**{ |
| | "searches": topics |
| | }) |
| |
|
| | |
| | consolidation_completion = await llm_router.acompletion(model="gemini-v2", messages=[ |
| | {"role": "user", "content": consolidation_prompt} |
| | ], temperature=0.5) |
| |
|
| | return PriorArtSearchResponse(content=consolidation_completion.choices[0].message.content, topic_contents=topics) |
| |
|