"""Demonstrate a streaming LlamaIndex workflow backed by a Groq LLM.

Three workflow steps write ``ProgressEvent`` messages to the workflow's event
stream; the second step streams an LLM completion chunk by chunk. After the
run finishes, the workflow graph is rendered to an HTML file and displayed.
"""

import asyncio
import os
from pathlib import Path

import nest_asyncio
from dotenv import load_dotenv
from IPython.display import HTML, display
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from llama_index.llms.groq import Groq
from llama_index.utils.workflow import draw_all_possible_flows

from helper import extract_html_content

# Patch asyncio so a nested event loop is permitted (needed when this code
# runs somewhere a loop is already running, e.g. a notebook).
nest_asyncio.apply()

# Populate os.environ from a local .env file, if one exists.
load_dotenv()

GROQ_API_KEY = os.getenv("GROQ_API_KEY")
global_llm = Groq(api_key=GROQ_API_KEY, model="llama3-70b-8192")


class FirstEvent(Event):
    """Emitted by ``step_one``; consumed by ``step_two``."""

    first_output: str


class SecondEvent(Event):
    """Emitted by ``step_two``; carries the streamed LLM response text."""

    second_output: str
    response: str


class ProgressEvent(Event):
    """A progress message written to the workflow's event stream."""

    msg: str


class MyWorkflow(Workflow):
    """Three-step workflow that streams progress and LLM output to callers."""

    @step
    async def step_one(self, ctx: Context, ev: StartEvent) -> FirstEvent:
        """Report that the workflow has started and hand off to step two."""
        ctx.write_event_to_stream(ProgressEvent(msg="Step one is happening"))
        return FirstEvent(first_output="First step complete.")

    @step
    async def step_two(self, ctx: Context, ev: FirstEvent) -> SecondEvent:
        """Stream an LLM completion, forwarding each new piece as progress.

        Returns:
            A ``SecondEvent`` whose ``response`` is ``str()`` of the last
            streamed chunk, or ``""`` if the stream produced no chunks.
        """
        llm = global_llm
        generator = await llm.astream_complete(
            "Please give me the first 3 paragraphs of Moby Dick, a book in the public domain."
        )
        # Track the last chunk explicitly: reading the loop variable after the
        # loop raises UnboundLocalError when the stream is empty.
        last_chunk = None
        async for chunk in generator:
            last_chunk = chunk
            # Forward only the newly generated text. ``delta`` may be None, and
            # ProgressEvent.msg is typed ``str``.
            ctx.write_event_to_stream(ProgressEvent(msg=chunk.delta or ""))
        # In llama_index streaming, each chunk's text is the cumulative output
        # so far, so the last chunk holds the whole response.
        full_text = str(last_chunk) if last_chunk is not None else ""
        return SecondEvent(
            second_output="Second step complete, full response attached",
            response=full_text,
        )

    @step
    async def step_three(self, ctx: Context, ev: SecondEvent) -> StopEvent:
        """Report the final step and end the workflow."""
        ctx.write_event_to_stream(ProgressEvent(msg="Step three is happening"))
        return StopEvent(result="Workflow complete.")


async def main():
    """Run the workflow, print streamed progress, then render its graph."""
    w = MyWorkflow(timeout=30, verbose=True)
    handler = w.run(first_input="Start the workflow.")

    async for ev in handler.stream_events():
        if isinstance(ev, ProgressEvent):
            print(ev.msg)

    final_result = await handler
    print("Final result", final_result)

    workflow_file = Path(__file__).parent / "workflows" / "streaming_workflow.html"
    # The output directory may not exist yet (e.g. a fresh checkout); create
    # it so writing the HTML file does not fail.
    workflow_file.parent.mkdir(parents=True, exist_ok=True)
    draw_all_possible_flows(w, filename=str(workflow_file))
    html_content = extract_html_content(str(workflow_file))
    display(HTML(html_content), metadata=dict(isolated=True))


if __name__ == "__main__":
    asyncio.run(main())