Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from gradio import ChatMessage | |
| from langchain_core.messages import BaseMessage, HumanMessage | |
| import time | |
| from ml_pipeline_workflow import ml_app as workflow | |
def format_namespace(namespace):
    """Extract the graph name from a LangGraph namespace path.

    The last namespace segment has the form ``"<node_name>:<task_id>"``;
    the portion before the first ``:`` is returned. An empty namespace
    denotes the top-level graph, reported as ``"root graph"``.
    """
    if not namespace:
        return "root graph"
    graph_name, _, _ = namespace[-1].partition(":")
    return graph_name
def handle_like(data: gr.LikeData):
    """Log a chatbot like/dislike event to the console."""
    verdict = "๐ Upvoted" if data.liked else "๐ Downvoted"
    print(f"{verdict}: {data.value}")
def generate_response(message, history):
    """Stream a nested-thought response from the ML pipeline workflow.

    Runs the LangGraph ``workflow`` on the user's message and yields a
    growing list of ``gr.ChatMessage`` objects after each update, so the
    Gradio chatbot renders intermediate agent/node activity live. Agent
    subgraphs get a "header" message that nests their step messages
    (via ``parent_id``); pending steps are marked done — with a duration —
    when the next step starts or the stream ends.

    Args:
        message: The user's chat input text.
        history: Prior chat history from gr.ChatInterface (unused).

    Yields:
        list[ChatMessage]: The full response list after each update.
    """
    inputs = {
        "messages": [HumanMessage(content=message)],
    }
    response = []
    message_id_counter = 0   # Monotonic id source for ChatMessage metadata.
    agent_header_ids = {}    # agent name -> its header message id
    current_namespace = None # Last-seen namespace, to detect agent changes.
    node_chunk = None        # Last node payload; stays None if the stream yields nothing
                             # (fixes a NameError in the final-message check below).

    for namespace, chunk in workflow.stream(
        inputs,
        stream_mode="updates",
        subgraphs=True
    ):
        for node_name, node_chunk in chunk.items():
            formatted_namespace = format_namespace(namespace)

            # Complete the previous non-header message FIRST, before the
            # namespace-change check, so its duration is recorded.
            if response and response[-1].metadata.get("status") == "pending":
                # Agent headers are completed separately on namespace change.
                if response[-1].metadata.get("id") not in agent_header_ids.values():
                    prev_msg = response[-1]
                    prev_msg.metadata["status"] = "done"
                    if "start_time" in prev_msg.metadata:
                        prev_msg.metadata["duration"] = time.time() - prev_msg.metadata["start_time"]
                    # Reflect the completed step in its parent agent header.
                    prev_parent_id = prev_msg.metadata.get("parent_id")
                    if prev_parent_id:
                        prev_title = prev_msg.metadata.get("title", "unknown")
                        for msg in response:
                            if msg.metadata.get("id") == prev_parent_id and msg.metadata.get("is_agent_header"):
                                # Replace a "working on" line, otherwise append a completion log.
                                if not msg.content or "๐" in msg.content:
                                    msg.content = f"โ {prev_title}"
                                else:
                                    msg.content += f"\nโ {prev_title}"
                                break
                    yield response  # Yield to show completion

            # Namespace changed: the previous agent is finished, so mark
            # its header message as done.
            if current_namespace and current_namespace != formatted_namespace:
                if current_namespace in agent_header_ids:
                    header_id = agent_header_ids[current_namespace]
                    for msg in response:
                        if msg.metadata.get("id") == header_id and msg.metadata.get("status") == "pending":
                            msg.metadata["status"] = "done"
                            if "start_time" in msg.metadata:
                                msg.metadata["duration"] = time.time() - msg.metadata["start_time"]
                            yield response
                            break
            current_namespace = formatted_namespace

            # Nodes inside an agent subgraph nest under a header message;
            # create that header the first time the agent appears.
            parent_id = None
            if formatted_namespace != "root graph":
                if formatted_namespace not in agent_header_ids:
                    message_id_counter += 1
                    header_id = message_id_counter
                    agent_header_ids[formatted_namespace] = header_id
                    agent_emojis = {
                        "supervisor": "๐",
                        "data_extraction_expert": "๐",
                        "pretraining_expert": "๐ฌ",
                        "finetuning_expert": "๐ฏ",
                        "evaluation_expert": "๐"
                    }
                    emoji = agent_emojis.get(formatted_namespace, "๐ง ")
                    header_message = ChatMessage(
                        content="",
                        metadata={
                            "title": f"{emoji} {formatted_namespace}",
                            "id": header_id,
                            "status": "pending",
                            "start_time": time.time(),
                            "is_agent_header": True  # Distinguishes headers from step messages.
                        }
                    )
                    response.append(header_message)
                    yield response
                parent_id = agent_header_ids[formatted_namespace]

            # Create the message for the current node.
            message_id_counter += 1
            current_id = message_id_counter
            agent_names = ["data_extraction_expert", "pretraining_expert",
                           "finetuning_expert", "evaluation_expert"]
            if formatted_namespace == "root graph":
                if node_name == "supervisor":
                    title = f"๐ {node_name}"
                elif node_name in agent_names:
                    # Skip: these nodes are rendered when their subgraph streams.
                    continue
                else:
                    title = f"๐ง {node_name}"
            else:
                # Step inside an agent subgraph.
                title = f"โ๏ธ {node_name}"
            node_message = ChatMessage(
                content="",
                metadata={
                    "title": title,
                    "id": current_id,
                    "status": "pending",
                    "start_time": time.time()
                }
            )
            if parent_id:
                node_message.metadata["parent_id"] = parent_id
            response.append(node_message)
            yield response

            # Render the node payload as the message content.
            out_str = []
            if isinstance(node_chunk, dict):
                for k, v in node_chunk.items():
                    if isinstance(v, BaseMessage):
                        out_str.append(v.pretty_repr())
                    elif isinstance(v, list):
                        for list_item in v:
                            if isinstance(list_item, BaseMessage):
                                out_str.append(list_item.pretty_repr())
                            else:
                                out_str.append(str(list_item))
                    else:
                        out_str.append(f"{k}:\n{v}")
            response[-1].content = "\n".join(out_str)

            # Update the parent agent header with the current activity.
            if parent_id:
                for msg in response:
                    if msg.metadata.get("id") == parent_id and msg.metadata.get("is_agent_header"):
                        activity_summary = f"๐ Working on: {node_name}"
                        if len(out_str) > 0:
                            # Brief single-line preview of the content.
                            preview = out_str[0][:100].replace('\n', ' ')
                            activity_summary += f"\n๐ {preview}..."
                        msg.content = activity_summary
                        break
                yield response

            # Keep the node pending — it is completed when the next node
            # starts; just yield to show the current state.
            yield response

    # Finalize any remaining pending messages (agent headers and steps).
    for msg in response:
        if msg.metadata.get("status") == "pending":
            msg.metadata["status"] = "done"
            if "start_time" in msg.metadata:
                msg.metadata["duration"] = time.time() - msg.metadata["start_time"]

    # Final assistant reply, appended without metadata so it displays as a
    # regular message. node_chunk is None when the stream produced nothing.
    if node_chunk and isinstance(node_chunk, dict) and 'messages' in node_chunk:
        final_message = node_chunk['messages'][-1].content
        response.append(ChatMessage(content=final_message))
    yield response
# Create interface with Blocks for like functionality:
# a Chatbot (messages protocol, copy/share buttons enabled) embedded in a
# ChatInterface driven by the streaming generate_response generator.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        type="messages",
        show_copy_button=True,
        show_copy_all_button=True,
        show_share_button=True
    )
    # Route thumb-up/down events to the console logger.
    chatbot.like(handle_like, None, None)
    gr.ChatInterface(
        generate_response,
        type="messages",
        chatbot=chatbot,
        title="๐ฌ ML Pipeline Automation",
        description="๋ฐ์ดํฐ ์ถ์ถ, ์ฌ์ ํ์ต, ํ์ธํ๋, ํ๊ฐ ๋จ๊ณ๋ฅผ ๊ฑฐ์น๋ ์์ ํ ML ํ์ดํ๋ผ์ธ์ ๋๋ค. ๊ฐ ์ ๋ฌธ๊ฐ ์์ด์ ํธ์ ์์ ๊ณผ์ ์ ๋จ๊ณ๋ณ๋ก ํ์ธํ ์ ์์ต๋๋ค.",
        examples=[
            # Full-pipeline examples (extract -> pretrain -> finetune -> evaluate)
            "user_events ํ ์ด๋ธ์์ 2024-01-01๋ถํฐ 2024-12-31๊น์ง ์ด๋ฒคํธ ๋ฐ์ดํฐ๋ฅผ ์ถ์ถํ๊ณ , ๋ชจ๋ธ์ ์ฌ์ ํ์ตํ ํ 5๊ฐ ํด๋์ค ๋ถ๋ฅ ๋ชจ๋ธ์ ํ์ตํ๊ณ  ํ๊ฐํด์ค",
            "transaction_data ํ ์ด๋ธ์์ ์ต๊ทผ 6๊ฐ์ ๋ฐ์ดํฐ๋ก ์ด์๊ฑฐ๋ ํ์ง ๋ชจ๋ธ์ ์ฒ์๋ถํฐ ๋๊น์ง ๋ง๋ค์ด์ค",
            # Combined examples (2-3 stages)
            "customer_logs ํ ์ด๋ธ์์ 2024๋ ๋ฐ์ดํฐ๋ฅผ ์ถ์ถํ๊ณ  GPT-2 ๋ชจ๋ธ ์ฌ์ ํ์ต๊น์ง๋ง ํด์ค",
            "์ค๋น๋ ๋ฐ์ดํฐ(/data/corpus.txt)๋ก ๋ชจ๋ธ์ ์ฌ์ ํ์ตํ๊ณ , ๊ทธ ๋ชจ๋ธ๋ก ๊ฐ์ฑ๋ถ์ 3ํด๋์ค ๋ถ๋ฅ ํ์ธํ๋๊น์ง ์งํํด์ค",
            "์ฌ์ ํ์ต๋ ๋ชจ๋ธ(/models/pretrained/gpt2)๋ก spam detection 2ํด๋์ค ๋ถ๋ฅ ๋ชจ๋ธ ํ์ตํ๊ณ  ์ฑ๋ฅ ํ๊ฐํด์ค",
            "product_reviews ํ ์ด๋ธ์์ 2024๋ ๋ฆฌ๋ทฐ ๋ฐ์ดํฐ ์ถ์ถํ๊ณ  ๋ฐ๋ก ๋ณ์  ์์ธก 5ํด๋์ค ๋ถ๋ฅ ๋ชจ๋ธ ํ์ธํ๋ํด์ค",
            # Single-stage examples
            "์ค๋น๋ ๋ฐ์ดํฐ(/data/pretraining/corpus.txt)๋ก GPT-2 ๋ชจ๋ธ์ 3 ์ํฌํฌ ์ฌ์ ํ์ต๋ง ํด์ค",
            "์ฌ์ ํ์ต๋ ๋ชจ๋ธ(/models/pretrained/checkpoint)์ผ๋ก ๊ฐ์ฑ๋ถ์ 3ํด๋์ค ๋ถ๋ฅ ํ์ธํ๋๋ง ์งํํด์ค",
            "ํ์ต๋ ๋ชจ๋ธ(/models/finetuned/model)์ ํ ์คํธ ๋ฐ์ดํฐ(/data/test.jsonl)๋ก ํ๊ฐ๋ง ํด์ค"
        ],
        cache_examples=False
    )
if __name__ == "__main__":
    # ssr_mode=False disables server-side rendering (Spaces compatibility).
    demo.launch(ssr_mode=False)