Yoon-gu Hwang Claude commited on
Commit
a8500b3
·
1 Parent(s): c9a0817

Add ML pipeline automation system with tabbed interface

Browse files

- Create tabbed interface with Basic Chatbot and ML Pipeline tabs
- Add tab_basic_chatbot.py for existing multi-agent chatbot
- Add tab_ml_pipeline.py for ML pipeline automation
- Add ml_workflow.py with 4 specialized agents (Korean prompts):
  - data_extraction_expert: SQL and RDB operations
  - pretraining_expert: Data preparation and model pretraining
  - finetuning_expert: Finetuning data creation and classification training
  - evaluation_expert: Model evaluation with precision, recall, F1, accuracy
- Add tools.py with realistic ML pipeline tool functions:
  - extract_events_from_rdb, prepare_pretraining_data
  - pretrain_model, create_finetuning_data
  - train_classification_model, evaluate_model
- Update README.md with new features and project structure
- Simplify app.py to use TabbedInterface

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (6) hide show
  1. README.md +16 -3
  2. app.py +7 -86
  3. ml_workflow.py +87 -0
  4. tab_basic_chatbot.py +85 -0
  5. tab_ml_pipeline.py +88 -0
  6. tools.py +236 -0
README.md CHANGED
@@ -11,16 +11,25 @@ pinned: false
11
 
12
  # LangGraph UI
13
 
14
- A Gradio-based chat interface for LangGraph supervisor workflow with nested agent visualization.
15
 
16
  ## Features
17
 
 
18
  - 🤖 Multi-agent system with supervisor workflow
19
  - 🔍 Research agent with web search capabilities
20
  - 🧮 Math agent for calculations
21
  - 💬 Interactive chat interface with nested thoughts visualization
22
  - 🎨 Real-time streaming of agent execution steps
23
 
 
 
 
 
 
 
 
 
24
  ## Setup
25
 
26
  ### Local Development
@@ -55,8 +64,12 @@ To deploy on HuggingFace Spaces, you need to set the `OPENAI_API_KEY` as a secre
55
 
56
  ## Project Structure
57
 
58
- - `app.py` - Gradio chat interface
59
- - `graph.py` - LangGraph supervisor workflow and agents
 
 
 
 
60
  - `pyproject.toml` - Project dependencies managed by uv
61
  - `.env` - Environment variables (not tracked in git)
62
 
 
11
 
12
  # LangGraph UI
13
 
14
+ A Gradio-based chat interface for LangGraph supervisor workflow with nested agent visualization and ML pipeline automation.
15
 
16
  ## Features
17
 
18
+ ### Basic Chatbot
19
  - 🤖 Multi-agent system with supervisor workflow
20
  - 🔍 Research agent with web search capabilities
21
  - 🧮 Math agent for calculations
22
  - 💬 Interactive chat interface with nested thoughts visualization
23
  - 🎨 Real-time streaming of agent execution steps
24
 
25
+ ### ML Pipeline Automation
26
+ - 📊 Data extraction from RDB tables using SQL
27
+ - 🔤 Language model pretraining with tokenization
28
+ - 🎯 Classification model finetuning
29
+ - 📈 Comprehensive model evaluation (Precision, Recall, F1-score, Accuracy)
30
+ - 🤝 4 specialized agents coordinated by supervisor
31
+ - 🇰🇷 Korean language support
32
+
33
  ## Setup
34
 
35
  ### Local Development
 
64
 
65
  ## Project Structure
66
 
67
+ - `app.py` - Main application with tabbed interface
68
+ - `tab_basic_chatbot.py` - Basic chatbot with research and math agents
69
+ - `tab_ml_pipeline.py` - ML pipeline automation chatbot
70
+ - `graph.py` - Basic LangGraph supervisor workflow
71
+ - `ml_workflow.py` - ML pipeline supervisor workflow with 4 specialized agents
72
+ - `tools.py` - ML pipeline tools (data extraction, pretraining, finetuning, evaluation)
73
  - `pyproject.toml` - Project dependencies managed by uv
74
  - `.env` - Environment variables (not tracked in git)
75
 
app.py CHANGED
@@ -1,91 +1,12 @@
1
  import gradio as gr
2
- import time
3
- from gradio import ChatMessage
4
- from langchain_core.runnables import RunnableConfig
5
- from langchain_core.messages import BaseMessage, HumanMessage
6
- from pprint import pprint
7
- from graph import app as workflow
8
 
9
- def format_namespace(namespace):
10
- return namespace[-1].split(":")[0] if len(namespace) > 0 else "root graph"
11
-
12
- def generate_response(message, history):
13
- inputs = {
14
- "messages": [HumanMessage(content=message)],
15
- }
16
- node_names = []
17
- response = []
18
- for namespace, chunk in workflow.stream(
19
- inputs,
20
- stream_mode="updates", subgraphs=True
21
- ):
22
- for node_name, node_chunk in chunk.items():
23
- # node_names가 비어있지 않은 경우에만 필터링
24
- if len(node_names) > 0 and node_name not in node_names:
25
- continue
26
-
27
- if len(response) > 0:
28
- response[-1].metadata["status"] = "done"
29
- # print("\n" + "=" * 50)
30
- msg = []
31
- formatted_namespace = format_namespace(namespace)
32
- if formatted_namespace == "root graph":
33
- print(f"🔄 Node: \033[1;36m{node_name}\033[0m 🔄")
34
- meta_title = f"🤔 `{node_name}`"
35
- else:
36
- print(
37
- f"🔄 Node: \033[1;36m{node_name}\033[0m in [\033[1;33m{formatted_namespace}\033[0m] 🔄"
38
- )
39
- meta_title = f"🤔 `{node_name}` in `{formatted_namespace}`"
40
-
41
- response.append(ChatMessage(content="", metadata={"title": meta_title, "status": "pending"}))
42
- yield response
43
- print("- " * 25)
44
-
45
- # 노드의 청크 데이터 출력
46
- out_str = []
47
- if isinstance(node_chunk, dict):
48
- for k, v in node_chunk.items():
49
- if isinstance(v, BaseMessage):
50
- v.pretty_print()
51
- out_str.append(v.pretty_repr())
52
- elif isinstance(v, list):
53
- for list_item in v:
54
- if isinstance(list_item, BaseMessage):
55
- list_item.pretty_print()
56
- out_str.append(list_item.pretty_repr())
57
- else:
58
- out_str.append(list_item)
59
- print(list_item)
60
- elif isinstance(v, dict):
61
- for node_chunk_key, node_chunk_value in node_chunk.items():
62
- out_str.append(f"{node_chunk_key}:\n{node_chunk_value}")
63
- print(f"{node_chunk_key}:\n{node_chunk_value}")
64
- else:
65
- out_str.append(f"{k}:\n{v}")
66
- print(f"\033[1;32m{k}\033[0m:\n{v}")
67
- response[-1].content = "\n".join(out_str)
68
- yield response
69
- else:
70
- if node_chunk is not None:
71
- for item in node_chunk:
72
- out_str.append(item)
73
- print(item)
74
- response[-1].content = "\n".join(out_str)
75
- yield response
76
- yield response
77
- print("=" * 50)
78
- response[-1].metadata["status"] = "done"
79
- response.append(ChatMessage(content=node_chunk['messages'][-1].content))
80
- yield response
81
-
82
- demo = gr.ChatInterface(
83
- generate_response,
84
- type="messages",
85
- title="Nested Thoughts Chat Interface",
86
- examples=["2024년의 the FAANG companies 총 근로자규모에 대한 분석을 한국어로 부탁해!"],
87
- cache_examples=False
88
  )
89
 
90
  if __name__ == "__main__":
91
- demo.launch(ssr_mode=False)
 
1
  import gradio as gr
2
+ from tab_basic_chatbot import demo as basic_chatbot
3
+ from tab_ml_pipeline import demo as ml_pipeline
 
 
 
 
4
 
5
+ demo = gr.TabbedInterface(
6
+ [basic_chatbot, ml_pipeline],
7
+ ["Basic Chatbot", "ML Pipeline"],
8
+ title="Multi-Agent Systems"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  )
10
 
11
  if __name__ == "__main__":
12
+ demo.launch(ssr_mode=False)
ml_workflow.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""ML pipeline supervisor workflow.

Wires four specialized ReAct agents (data extraction, pretraining,
finetuning, evaluation) under a single supervisor and compiles the
resulting graph as ``ml_app``.  All agent prompts are in Korean by design.
"""
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langgraph_supervisor import create_supervisor
from langgraph.prebuilt import create_react_agent
from tools import (
    extract_events_from_rdb,
    prepare_pretraining_data,
    pretrain_model,
    create_finetuning_data,
    train_classification_model,
    evaluate_model
)

# Load OPENAI_API_KEY (and any other settings) from .env before building the model.
load_dotenv()

# Single shared chat model for the supervisor and all four agents.
model = ChatOpenAI(model="gpt-4o")

# Agent 1: extracts event records from RDB tables via SQL (see tools.extract_events_from_rdb).
data_extraction_agent = create_react_agent(
    model=model,
    tools=[extract_events_from_rdb],
    name="data_extraction_expert",
    prompt=(
        "당신은 SQL과 RDB 작업에 특화된 데이터 추출 전문가입니다. "
        "데이터베이스 테이블에서 이벤트 레코드를 추출하고 텍스트 형식으로 변환하는 역할을 합니다. "
        "테이블 이름, 날짜 범위, 이벤트 타입에 대한 명확한 정보를 제공해야 합니다. "
        "레코드 수와 파일 크기를 포함한 추출 통계를 보고하세요."
    )
)

# Agent 2: prepares tokenized data and pretrains the language model.
pretraining_agent = create_react_agent(
    model=model,
    tools=[prepare_pretraining_data, pretrain_model],
    name="pretraining_expert",
    prompt=(
        "당신은 언어모델 사전학습 전문가입니다. "
        "토큰화된 데이터를 준비하고 모델을 처음부터 학습시키는 책임을 맡고 있습니다. "
        "Loss와 Perplexity 같은 학습 지표를 모니터링하세요. "
        "데이터 준비 및 모델 학습 진행 상황에 대한 자세한 통계를 보고하세요. "
        "한 번에 하나의 도구만 사용하세요."
    )
)

# Agent 3: builds the finetuning dataset and trains the classifier.
finetuning_agent = create_react_agent(
    model=model,
    tools=[create_finetuning_data, train_classification_model],
    name="finetuning_expert",
    prompt=(
        "당신은 분류 작업에 특화된 파인튜닝 전문가입니다. "
        "고품질의 파인튜닝 데이터셋을 만들고 분류 모델을 학습시키는 역할을 합니다. "
        "적절한 데이터 분할과 클래스 분포를 보장하세요. "
        "파인튜닝 과정 전반에 걸쳐 학습 및 검증 지표를 모니터링하세요. "
        "한 번에 하나의 도구만 사용하세요."
    )
)

# Agent 4: evaluates the trained classifier (precision/recall/F1/accuracy).
evaluation_agent = create_react_agent(
    model=model,
    tools=[evaluate_model],
    name="evaluation_expert",
    prompt=(
        "당신은 분류 지표에 특화된 모델 평가 전문가입니다. "
        "Precision, Recall, F1-score, Accuracy를 사용하여 학습된 모델을 철저히 평가하는 역할을 합니다. "
        "클래스별 세부 지표와 전체 성능 통계를 제공하세요. "
        "Confusion matrix를 분석하고 개선이 필요한 영역을 파악하세요."
    )
)

# Supervisor: routes work through the four agents in pipeline order
# (extract -> pretrain -> finetune -> evaluate).
workflow = create_supervisor(
    [data_extraction_agent, pretraining_agent, finetuning_agent, evaluation_agent],
    model=model,
    prompt=(
        "당신은 완전한 모델 개발 워크플로우를 관리하는 ML 파이프라인 감독자입니다. "
        "팀은 4명의 전문가로 구성되어 있습니다:\n\n"
        "1. data_extraction_expert: SQL을 사용하여 RDB 테이블에서 이벤트 데이터를 추출합니다\n"
        "2. pretraining_expert: 데이터를 준비하고 언어모델을 사전학습시킵니다\n"
        "3. finetuning_expert: 파인튜닝 데이터를 생성하고 분류 모델을 학습시킵니다\n"
        "4. evaluation_expert: Precision, Recall 등의 지표로 모델을 평가합니다\n\n"
        "작업 순서:\n"
        "1. data_extraction_expert를 사용하여 RDB에서 이벤트 데이터를 추출합니다\n"
        "2. pretraining_expert를 사용하여 데이터를 준비하고 모델을 사전학습시킵니다\n"
        "3. finetuning_expert를 사용하여 파인튜닝 데이터를 만들고 분류 모델을 학습시킵니다\n"
        "4. evaluation_expert를 사용하여 최종 모델을 평가합니다\n\n"
        "팀을 조율하여 전체 ML 파이프라인을 효율적으로 완료하세요."
    )
)

# Compiled graph consumed by tab_ml_pipeline.py.
ml_app = workflow.compile()
tab_basic_chatbot.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio import ChatMessage
3
+ from langchain_core.messages import BaseMessage, HumanMessage
4
+ from graph import app as workflow
5
+
6
+
7
def format_namespace(namespace):
    """Return the agent name for a stream namespace, or "root graph" when empty.

    A namespace entry looks like "agent_name:run_id"; only the part before
    the first colon is kept.
    """
    if not namespace:
        return "root graph"
    return namespace[-1].split(":")[0]
9
+
10
+
11
def generate_response(message, history):
    """Stream the supervisor workflow for *message*, yielding ChatMessage lists.

    Each workflow node update becomes a "pending" metadata-tagged bubble
    (rendered by gradio as a collapsible thought); the previous bubble is
    marked "done" when a new one opens.  The final assistant answer is
    appended as a plain message at the end.
    """
    inputs = {
        "messages": [HumanMessage(content=message)],
    }
    # Optional whitelist of node names to display; empty list means show all.
    node_names = []
    response = []
    # stream_mode="updates" with subgraphs=True yields (namespace, chunk)
    # pairs for every node update, including nodes in nested agent graphs.
    for namespace, chunk in workflow.stream(
        inputs,
        stream_mode="updates", subgraphs=True
    ):
        for node_name, node_chunk in chunk.items():
            if len(node_names) > 0 and node_name not in node_names:
                continue

            # Close the previous pending thought bubble, if any.
            if len(response) > 0:
                response[-1].metadata["status"] = "done"
            msg = []  # NOTE(review): unused — looks like leftover scratch
            formatted_namespace = format_namespace(namespace)
            if formatted_namespace == "root graph":
                print(f"🔄 Node: \033[1;36m{node_name}\033[0m 🔄")
                meta_title = f"🤔 `{node_name}`"
            else:
                print(
                    f"🔄 Node: \033[1;36m{node_name}\033[0m in [\033[1;33m{formatted_namespace}\033[0m] 🔄"
                )
                meta_title = f"🤔 `{node_name}` in `{formatted_namespace}`"

            # Open a new pending thought bubble for this node and push it to the UI.
            response.append(ChatMessage(content="", metadata={"title": meta_title, "status": "pending"}))
            yield response
            print("- " * 25)

            # Render this node's chunk payload into the pending bubble.
            out_str = []
            if isinstance(node_chunk, dict):
                for k, v in node_chunk.items():
                    if isinstance(v, BaseMessage):
                        v.pretty_print()
                        out_str.append(v.pretty_repr())
                    elif isinstance(v, list):
                        for list_item in v:
                            if isinstance(list_item, BaseMessage):
                                list_item.pretty_print()
                                out_str.append(list_item.pretty_repr())
                            else:
                                out_str.append(list_item)
                                print(list_item)
                    elif isinstance(v, dict):
                        # NOTE(review): iterates node_chunk (the outer dict), not v —
                        # possibly unintentional; preserved as-is.
                        for node_chunk_key, node_chunk_value in node_chunk.items():
                            out_str.append(f"{node_chunk_key}:\n{node_chunk_value}")
                            print(f"{node_chunk_key}:\n{node_chunk_value}")
                    else:
                        out_str.append(f"{k}:\n{v}")
                        print(f"\033[1;32m{k}\033[0m:\n{v}")
                response[-1].content = "\n".join(out_str)
                yield response
            else:
                if node_chunk is not None:
                    for item in node_chunk:
                        out_str.append(item)
                        print(item)
                response[-1].content = "\n".join(out_str)
                yield response
            yield response
    print("=" * 50)
    # NOTE(review): assumes the stream yielded at least one chunk and that the
    # last node_chunk is a dict with a "messages" list — raises otherwise.
    response[-1].metadata["status"] = "done"
    response.append(ChatMessage(content=node_chunk['messages'][-1].content))
    yield response
77
+
78
+
79
# Chat UI for the basic multi-agent workflow; generate_response streams
# intermediate agent steps as metadata-tagged "thought" messages.
# Imported by app.py and mounted as the "Basic Chatbot" tab.
demo = gr.ChatInterface(
    generate_response,
    type="messages",
    title="Basic Multi-Agent Chatbot",
    examples=["2024년의 the FAANG companies 총 근로자규모에 대한 분석을 한국어로 부탁해!"],
    cache_examples=False
)
tab_ml_pipeline.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio import ChatMessage
3
+ from langchain_core.messages import BaseMessage, HumanMessage
4
+ from ml_workflow import ml_app as workflow
5
+
6
+
7
def format_namespace(namespace):
    """Map a stream namespace tuple/list to its agent name.

    Entries look like "agent_name:run_id"; an empty namespace denotes the
    top-level graph and maps to "root graph".
    """
    has_scope = len(namespace) > 0
    return namespace[-1].split(":", 1)[0] if has_scope else "root graph"
9
+
10
+
11
def generate_response(message, history):
    """Stream the ML-pipeline supervisor workflow, yielding ChatMessage lists.

    Mirrors the basic-chatbot tab: each node update opens a "pending"
    metadata-tagged bubble (a collapsible thought in gradio), the previous
    bubble is marked "done", and the final assistant answer is appended as
    a plain message.
    """
    inputs = {
        "messages": [HumanMessage(content=message)],
    }
    # Optional whitelist of node names to display; empty list means show all.
    node_names = []
    response = []
    # stream_mode="updates" with subgraphs=True yields (namespace, chunk)
    # pairs for every node update, including nodes in nested agent graphs.
    for namespace, chunk in workflow.stream(
        inputs,
        stream_mode="updates", subgraphs=True
    ):
        for node_name, node_chunk in chunk.items():
            if len(node_names) > 0 and node_name not in node_names:
                continue

            # Close the previous pending thought bubble, if any.
            if len(response) > 0:
                response[-1].metadata["status"] = "done"
            msg = []  # NOTE(review): unused — looks like leftover scratch
            formatted_namespace = format_namespace(namespace)
            if formatted_namespace == "root graph":
                print(f"🔄 Node: \033[1;36m{node_name}\033[0m 🔄")
                meta_title = f"🤔 `{node_name}`"
            else:
                print(
                    f"🔄 Node: \033[1;36m{node_name}\033[0m in [\033[1;33m{formatted_namespace}\033[0m] 🔄"
                )
                meta_title = f"🤔 `{node_name}` in `{formatted_namespace}`"

            # Open a new pending thought bubble for this node and push it to the UI.
            response.append(ChatMessage(content="", metadata={"title": meta_title, "status": "pending"}))
            yield response
            print("- " * 25)

            # Render this node's chunk payload into the pending bubble.
            out_str = []
            if isinstance(node_chunk, dict):
                for k, v in node_chunk.items():
                    if isinstance(v, BaseMessage):
                        v.pretty_print()
                        out_str.append(v.pretty_repr())
                    elif isinstance(v, list):
                        for list_item in v:
                            if isinstance(list_item, BaseMessage):
                                list_item.pretty_print()
                                out_str.append(list_item.pretty_repr())
                            else:
                                out_str.append(list_item)
                                print(list_item)
                    elif isinstance(v, dict):
                        # NOTE(review): iterates node_chunk (the outer dict), not v —
                        # possibly unintentional; preserved as-is.
                        for node_chunk_key, node_chunk_value in node_chunk.items():
                            out_str.append(f"{node_chunk_key}:\n{node_chunk_value}")
                            print(f"{node_chunk_key}:\n{node_chunk_value}")
                    else:
                        out_str.append(f"{k}:\n{v}")
                        print(f"\033[1;32m{k}\033[0m:\n{v}")
                response[-1].content = "\n".join(out_str)
                yield response
            else:
                if node_chunk is not None:
                    for item in node_chunk:
                        out_str.append(item)
                        print(item)
                response[-1].content = "\n".join(out_str)
                yield response
            yield response
    print("=" * 50)
    # NOTE(review): assumes the stream yielded at least one chunk and that the
    # last node_chunk is a dict with a "messages" list — raises otherwise.
    response[-1].metadata["status"] = "done"
    response.append(ChatMessage(content=node_chunk['messages'][-1].content))
    yield response
77
+
78
+
79
# Chat UI for the ML-pipeline workflow; generate_response streams
# intermediate agent steps as metadata-tagged "thought" messages.
# Imported by app.py and mounted as the "ML Pipeline" tab.
demo = gr.ChatInterface(
    generate_response,
    type="messages",
    title="ML Pipeline Automation System",
    examples=[
        "user_events 테이블에서 2024-01-01부터 2024-12-31까지 이벤트 데이터를 추출하고, 모델을 사전학습한 후 5개 클래스 분류 모델을 학습하고 평가해줘",
        "Extract events from transaction_logs table for Q1 2024, pretrain a GPT2 model, create finetuning data for 3-class classification, and evaluate the results"
    ],
    cache_examples=False
)
tools.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import time
from typing import Any, Dict, List, Optional
3
+
4
+
5
def extract_events_from_rdb(
    table_name: str,
    start_date: str,
    end_date: str,
    event_types: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Extract event records from RDB table and convert to text format.

    Args:
        table_name: Name of the RDB table
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format
        event_types: Optional list of event types to filter; when given,
            only those types appear in the simulated extraction

    Returns:
        Dictionary containing extracted data statistics and file path
    """
    time.sleep(0.5)  # simulate query latency
    # Full mock distribution; sums to 125847 so totals stay self-consistent.
    distribution = {
        "user_action": 45230,
        "system_event": 32145,
        "error_log": 18472,
        "transaction": 30000
    }
    if event_types is not None:
        # Honor the filter instead of silently ignoring it (previous behavior
        # returned the full distribution regardless of event_types).
        distribution = {k: v for k, v in distribution.items() if k in event_types}
    records = sum(distribution.values())
    return {
        "status": "success",
        "records_extracted": records,
        "output_file": f"/data/events/{table_name}_{start_date}_{end_date}.txt",
        # Scale the mock file size with the filtered record count.
        "total_size_mb": round(482.3 * records / 125847, 1),
        "event_type_distribution": distribution,
        "processing_time_seconds": 12.5
    }
37
+
38
+
39
def prepare_pretraining_data(
    input_file: str,
    tokenizer: str = "gpt2",
    max_length: int = 512,
    min_length: int = 50
) -> Dict[str, Any]:
    """
    Prepare text data for pretraining by tokenization and formatting.

    This is a mock tool: it reports fixed statistics regardless of the
    arguments, simulating the work with a short delay.

    Args:
        input_file: Path to input text file
        tokenizer: Tokenizer to use
        max_length: Maximum sequence length
        min_length: Minimum sequence length

    Returns:
        Dictionary containing prepared data statistics
    """
    time.sleep(0.5)  # simulated tokenization pass
    report: Dict[str, Any] = {"status": "success"}
    report["output_file"] = "/data/pretraining/tokenized_data.bin"
    report["total_sequences"] = 89234
    report["total_tokens"] = 45623890
    report["avg_sequence_length"] = 511.2
    report["vocab_size"] = 50257
    report["processing_time_seconds"] = 34.2
    return report
67
+
68
+
69
def pretrain_model(
    data_file: str,
    model_architecture: str = "gpt2-small",
    num_epochs: int = 3,
    batch_size: int = 32,
    learning_rate: float = 5e-5
) -> Dict[str, Any]:
    """
    Pretrain language model on prepared data.

    This is a mock tool: it reports a fixed training run regardless of the
    arguments, simulating the work with a short delay.

    Args:
        data_file: Path to tokenized data
        model_architecture: Model architecture to use
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        learning_rate: Learning rate

    Returns:
        Dictionary containing training metrics and model path
    """
    time.sleep(0.5)  # simulated training run
    per_epoch_losses = {
        "epoch_1_loss": 3.245,
        "epoch_2_loss": 2.789,
        "epoch_3_loss": 2.341,
    }
    summary: Dict[str, Any] = {
        "status": "success",
        "model_path": "/models/pretrained/model_checkpoint_epoch3",
        "final_loss": 2.341,
        "perplexity": 10.39,
        "training_time_hours": 4.5,
        "total_steps": 8340,
        "best_checkpoint": "checkpoint-7800",
        "gpu_hours": 36.0,
        "metrics": per_epoch_losses,
    }
    return summary
105
+
106
+
107
def create_finetuning_data(
    source_data: str,
    task_type: str = "classification",
    num_classes: int = 5,
    train_ratio: float = 0.8,
    augmentation: bool = True
) -> Dict[str, Any]:
    """
    Create finetuning dataset for classification task.

    This is a mock tool: sample counts and class distribution are fixed,
    but the ``augmentation`` flag is now reflected in the report (it was
    previously hardcoded to True and silently ignored).

    Args:
        source_data: Path to source data
        task_type: Type of task (classification, regression, etc.)
        num_classes: Number of classification classes
        train_ratio: Ratio of training data
        augmentation: Whether to apply data augmentation

    Returns:
        Dictionary containing dataset statistics and file paths
    """
    time.sleep(0.5)  # simulated dataset build
    # NOTE(review): counts and class_distribution are fixed mock values and do
    # not vary with num_classes/train_ratio; distribution sums to train_samples.
    return {
        "status": "success",
        "train_file": "/data/finetuning/train.jsonl",
        "val_file": "/data/finetuning/val.jsonl",
        "test_file": "/data/finetuning/test.jsonl",
        "train_samples": 12456,
        "val_samples": 3114,
        "test_samples": 3114,
        "class_distribution": {
            "class_0": 2489,
            "class_1": 3201,
            "class_2": 2845,
            "class_3": 2134,
            "class_4": 1787
        },
        # Fix: report the actual flag instead of a hardcoded True.
        "augmentation_applied": augmentation,
        "processing_time_seconds": 8.3
    }
146
+
147
+
148
def train_classification_model(
    pretrained_model: str,
    train_data: str,
    val_data: str,
    num_classes: int = 5,
    num_epochs: int = 10,
    batch_size: int = 16,
    learning_rate: float = 2e-5
) -> Dict[str, Any]:
    """
    Train classification model using finetuning data.

    This is a mock tool: it reports a fixed finetuning run (with early
    stopping at epoch 8) regardless of the arguments.

    Args:
        pretrained_model: Path to pretrained model
        train_data: Path to training data
        val_data: Path to validation data
        num_classes: Number of classes
        num_epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate

    Returns:
        Dictionary containing training results and model path
    """
    time.sleep(0.5)  # simulated finetuning run
    history = {
        "epoch_1": {"train_loss": 0.892, "val_loss": 0.845, "val_acc": 0.712},
        "epoch_5": {"train_loss": 0.345, "val_loss": 0.389, "val_acc": 0.887},
        "epoch_8": {"train_loss": 0.234, "val_loss": 0.312, "val_acc": 0.923},
    }
    result: Dict[str, Any] = {
        "status": "success",
        "model_path": "/models/finetuned/classification_model",
        "best_checkpoint": "checkpoint-epoch8",
        "final_train_loss": 0.234,
        "final_val_loss": 0.312,
        "best_val_accuracy": 0.923,
        "training_time_hours": 1.2,
        "total_steps": 7785,
        "early_stopping_epoch": 8,
        "metrics_per_epoch": history,
    }
    return result
189
+
190
+
191
+ def evaluate_model(
192
+ model_path: str,
193
+ test_data: str,
194
+ metrics: List[str] = None
195
+ ) -> Dict[str, Any]:
196
+ """
197
+ Evaluate trained model on test data with comprehensive metrics.
198
+
199
+ Args:
200
+ model_path: Path to trained model
201
+ test_data: Path to test data
202
+ metrics: List of metrics to compute
203
+
204
+ Returns:
205
+ Dictionary containing evaluation metrics
206
+ """
207
+ time.sleep(0.5)
208
+ if metrics is None:
209
+ metrics = ["precision", "recall", "f1", "accuracy"]
210
+
211
+ return {
212
+ "status": "success",
213
+ "test_samples": 3114,
214
+ "overall_accuracy": 0.918,
215
+ "macro_precision": 0.912,
216
+ "macro_recall": 0.908,
217
+ "macro_f1": 0.910,
218
+ "weighted_precision": 0.916,
219
+ "weighted_recall": 0.918,
220
+ "weighted_f1": 0.917,
221
+ "per_class_metrics": {
222
+ "class_0": {"precision": 0.935, "recall": 0.921, "f1": 0.928, "support": 623},
223
+ "class_1": {"precision": 0.948, "recall": 0.952, "f1": 0.950, "support": 640},
224
+ "class_2": {"precision": 0.899, "recall": 0.887, "f1": 0.893, "support": 569},
225
+ "class_3": {"precision": 0.887, "recall": 0.901, "f1": 0.894, "support": 427},
226
+ "class_4": {"precision": 0.891, "recall": 0.879, "f1": 0.885, "support": 357}
227
+ },
228
+ "confusion_matrix": [
229
+ [574, 12, 18, 10, 9],
230
+ [8, 609, 11, 7, 5],
231
+ [15, 9, 505, 28, 12],
232
+ [11, 8, 22, 385, 1],
233
+ [14, 6, 18, 5, 314]
234
+ ],
235
+ "inference_time_ms": 1247.5
236
+ }