princhman committed on
Commit
73d131e
·
verified ·
1 Parent(s): b544f5d

update worker.py

Browse files
Files changed (1) hide show
  1. worker.py +125 -125
worker.py CHANGED
@@ -1,126 +1,126 @@
1
- #!/usr/bin/env python3
2
- import os
3
- import json
4
- import time
5
- import threading
6
- import multiprocessing
7
- from concurrent.futures import ThreadPoolExecutor
8
- import pika
9
- from typing import Tuple, Dict, Any
10
-
11
- from mineru_single import Processor
12
-
13
- class MessageProcessor:
14
- def __init__(self):
15
- self.processor = Processor()
16
-
17
- def process_message(self, body_bytes: bytes) -> Tuple[str, Dict[str, Any]]:
18
- """Process incoming message and return processed results"""
19
- body_str = body_bytes.decode("utf-8")
20
- data = json.loads(body_str)
21
-
22
- headers = data.get("headers", {})
23
- request_type = headers.get("request_type", "")
24
- request_id = headers.get("request_id", "")
25
- body = data.get("body", {})
26
-
27
- if request_type != "process_files":
28
- return "No processing done", data
29
-
30
- input_files = body.get("input_files", [])
31
- topics = body.get("topics", [])
32
-
33
- urls, file_key_map = self._extract_urls_and_keys(input_files)
34
- batch_results = self.processor.process_batch(urls)
35
- md_context = self._create_markdown_context(batch_results, file_key_map)
36
-
37
- final_json = self._create_response_json(request_id, input_files, topics, md_context)
38
- return json.dumps(final_json, ensure_ascii=False), final_json
39
-
40
- def _extract_urls_and_keys(self, input_files: list) -> Tuple[list, dict]:
41
- """Extract URLs and create file key mapping"""
42
- urls = []
43
- file_key_map = {}
44
- for f in input_files:
45
- key = f.get("key", "")
46
- url = f.get("url", "")
47
- urls.append(url)
48
- file_key_map[url] = key
49
- return urls, file_key_map
50
-
51
- def _create_markdown_context(self, batch_results: dict, file_key_map: dict) -> list:
52
- """Create markdown context from batch results"""
53
- md_context = []
54
- for url, md_content in batch_results.items():
55
- key = file_key_map.get(url, "")
56
- md_context.append({"key": key, "body": md_content})
57
- return md_context
58
-
59
- def _create_response_json(self, request_id: str, input_files: list,
60
- topics: list, md_context: list) -> dict:
61
- """Create the final response JSON"""
62
- return {
63
- "headers": {
64
- "request_type": "question_extraction_update_from_gpu_server",
65
- "request_id": request_id
66
- },
67
- "body": {
68
- "input_files": input_files,
69
- "topics": topics,
70
- "md_context": md_context
71
- }
72
- }
73
-
74
- class RabbitMQWorker:
75
- def __init__(self, num_workers: int = 1):
76
- self.num_workers = num_workers
77
- self.message_processor = MessageProcessor()
78
-
79
- def callback(self, ch, method, properties, body):
80
- """Handle incoming RabbitMQ messages"""
81
- thread_id = threading.current_thread().name
82
- headers = properties.headers or {}
83
-
84
- print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")
85
-
86
- if headers.get("process") == "topic_extraction":
87
- raw_text_outputs, parsed_json_outputs = self.message_processor.process_message(body)
88
- print(f"[Worker {thread_id}] Pipeline result:\n{raw_text_outputs}")
89
- else:
90
- print(f"[Worker {thread_id}] Unknown process, sleeping 10s.")
91
- time.sleep(10)
92
- print("[Worker] Done")
93
-
94
- def worker(self, channel):
95
- """Worker process to consume messages"""
96
- try:
97
- channel.start_consuming()
98
- except Exception as e:
99
- print(f"[Worker] Error: {e}")
100
-
101
- def connect_to_rabbitmq(self):
102
- """Establish connection to RabbitMQ"""
103
- rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
104
- connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
105
- channel = connection.channel()
106
-
107
- channel.queue_declare(queue="ml_server", durable=True)
108
- channel.basic_qos(prefetch_count=1)
109
- channel.basic_consume(
110
- queue="ml_server",
111
- on_message_callback=self.callback,
112
- auto_ack=True
113
- )
114
- return connection, channel
115
-
116
- def start(self):
117
- """Start the worker threads"""
118
- print(f"Starting {self.num_workers} workers")
119
- with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
120
- for _ in range(self.num_workers):
121
- connection, channel = self.connect_to_rabbitmq()
122
- executor.submit(self.worker, channel)
123
-
124
- def main():
125
- worker = RabbitMQWorker()
126
  worker.start()
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import json
4
+ import time
5
+ import threading
6
+ import multiprocessing
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ import pika
9
+ from typing import Tuple, Dict, Any
10
+
11
+ from mineru_single import Processor
12
+
13
class MessageProcessor:
    """Turns raw queue payloads into markdown-context response messages."""

    def __init__(self):
        # Heavy-weight document processor; constructed once per worker.
        self.processor = Processor()

    def process_message(self, body_bytes: bytes) -> Tuple[str, Dict[str, Any]]:
        """Decode one message, run its files through the processor, build the reply.

        Returns a ``(serialized_json, parsed_json)`` pair. Messages whose
        ``request_type`` is not ``"process_files"`` are returned untouched with
        the sentinel string ``"No processing done"``.
        """
        data = json.loads(body_bytes.decode("utf-8"))

        headers = data.get("headers", {})
        if headers.get("request_type", "") != "process_files":
            # Not ours to handle — hand the payload back unchanged.
            return "No processing done", data

        request_id = headers.get("request_id", "")
        body = data.get("body", {})
        input_files = body.get("input_files", [])
        topics = body.get("topics", [])

        urls, file_key_map = self._extract_urls_and_keys(input_files)
        batch_results = self.processor.process_batch(urls)
        md_context = self._create_markdown_context(batch_results, file_key_map)

        final_json = self._create_response_json(request_id, input_files, topics, md_context)
        return json.dumps(final_json, ensure_ascii=False), final_json

    def _extract_urls_and_keys(self, input_files: list) -> Tuple[list, dict]:
        """Collect every file URL plus a URL -> key lookup table."""
        urls = [entry.get("url", "") for entry in input_files]
        file_key_map = {entry.get("url", ""): entry.get("key", "") for entry in input_files}
        return urls, file_key_map

    def _create_markdown_context(self, batch_results: dict, file_key_map: dict) -> list:
        """Pair each processed markdown body with its originating file key."""
        return [
            {"key": file_key_map.get(url, ""), "body": md_content}
            for url, md_content in batch_results.items()
        ]

    def _create_response_json(self, request_id: str, input_files: list,
                              topics: list, md_context: list) -> dict:
        """Assemble the outbound GPU-server response envelope."""
        response_headers = {
            "request_type": "question_extraction_update_from_gpu_server",
            "request_id": request_id,
        }
        response_body = {
            "input_files": input_files,
            "topics": topics,
            "md_context": md_context,
        }
        return {"headers": response_headers, "body": response_body}
73
+
74
class RabbitMQWorker:
    """Consumes the ``gpu_server`` queue and dispatches messages to MessageProcessor."""

    def __init__(self, num_workers: int = 1):
        self.num_workers = num_workers
        self.message_processor = MessageProcessor()

    def callback(self, ch, method, properties, body):
        """Handle one RabbitMQ delivery.

        Dispatch is wrapped in try/except: with ``auto_ack=True`` the broker
        has already acknowledged the message, so an uncaught exception here
        (e.g. malformed JSON inside ``process_message``) would only tear down
        the consumer thread's ioloop without any redelivery. Errors are logged
        and the consumer stays alive.
        """
        thread_id = threading.current_thread().name
        headers = properties.headers or {}

        print(f"[Worker {thread_id}] Received message: {body}, headers: {headers}")

        try:
            if headers.get("process") == "topic_extraction":
                raw_text_outputs, parsed_json_outputs = self.message_processor.process_message(body)
                print(f"[Worker {thread_id}] Pipeline result:\n{raw_text_outputs}")
            else:
                # Unknown process type: back off briefly instead of spinning.
                print(f"[Worker {thread_id}] Unknown process, sleeping 10s.")
                time.sleep(10)
        except Exception as e:
            # Contain per-message failures so one bad payload cannot kill the consumer.
            print(f"[Worker {thread_id}] Error processing message: {e}")
        print("[Worker] Done")

    def worker(self, channel):
        """Blocking consume loop for one thread; logs instead of crashing the pool."""
        try:
            channel.start_consuming()
        except Exception as e:
            print(f"[Worker] Error: {e}")

    def connect_to_rabbitmq(self):
        """Open a connection/channel pair bound to the ``gpu_server`` queue.

        Returns ``(connection, channel)``; the caller owns both.
        """
        rabbit_url = os.getenv("RABBITMQ_URL", "amqp://guest:guest@localhost:5672/")
        connection = pika.BlockingConnection(pika.URLParameters(rabbit_url))
        channel = connection.channel()

        channel.queue_declare(queue="gpu_server", durable=True)
        channel.basic_qos(prefetch_count=1)
        # NOTE(review): auto_ack=True acknowledges before processing, so a crash
        # mid-message loses it. Consider manual ack (basic_ack after processing)
        # if at-least-once delivery is required.
        channel.basic_consume(
            queue="gpu_server",
            on_message_callback=self.callback,
            auto_ack=True
        )
        return connection, channel

    def start(self):
        """Spin up ``num_workers`` consumer threads, each with its own connection."""
        print(f"Starting {self.num_workers} workers")
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            for _ in range(self.num_workers):
                # One BlockingConnection per thread: pika connections are not
                # safe to share across threads.
                _connection, channel = self.connect_to_rabbitmq()
                executor.submit(self.worker, channel)
123
+
124
def main():
    """Entry point: build the worker pool and start consuming."""
    worker = RabbitMQWorker()
    worker.start()


if __name__ == "__main__":
    # Fix: main() was defined but never invoked anywhere in the module,
    # so running the script did nothing. Standard guard keeps imports side-effect free.
    main()