TahaFawzyElshrif commited on
Commit
bb89f50
·
1 Parent(s): 4347b3c
Files changed (1) hide show
  1. Consumer.py +82 -0
Consumer.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Consumer
2
+ import time
3
+ import pika
4
+ import os
5
+ from Server import get_response
6
+ import json
7
+ from agent.agent_graph.StateTasks import ProblemState
8
+ import argparse
9
+ argparse_model = argparse.ArgumentParser()
10
+ argparse_model.add_argument("--id", type=int, default=0, help="Consumer ID")
11
+
12
+
13
+
14
+ consumer_id = argparse_model.parse_args().id
15
+
16
+
17
+ RABBITMQ_URL = os.environ["RABBITMQ_URL"]
18
+ QUEUE_NAME = os.environ["QUEUE_NAME"]
19
+
20
+
21
+
22
+ def model_call(request):
23
+ # fill with last state
24
+ try:
25
+ state = json.loads(request['last_state'])
26
+ except Exception:
27
+ state: ProblemState = {
28
+ "question": request['prompt'],
29
+ "memory": request['memory']
30
+ }
31
+
32
+ print(f"MODEL CALL WITH STATE {state} and PROMPT {request['prompt']} and MEMORY {request['memory']} and HT_TOKEN {request['ht_token']} and USER_EMAIL {request['user_email']} and USER_NAME {request['user_name']}")
33
+ request['ht_token'] ="hf_" + request['ht_token']
34
+ #answer = get_response(request['prompt'], request['memory'],request['ht_token'],state,request['user_email'],request['user_name'])
35
+
36
+ #print(f"ANSWER {answer}")
37
+ # drop unserlizable keys
38
+ #for k in ["llm","rag_model"]:
39
+ # answer[k] = ""
40
+
41
+
42
+ #return {"Data": answer}
43
+
44
+ def get_connection():
45
+ params = pika.URLParameters(RABBITMQ_URL)
46
+ return pika.BlockingConnection(params)
47
+
48
+ def callback(ch, method, properties, body):
49
+ recieved_msg = json.loads(body.decode())
50
+ print("-------------------------------------------------")
51
+ print(f"MSG AT CONSUMER {consumer_id}" )
52
+
53
+ # simulate processing
54
+ print(f"TYPE {type(recieved_msg)}, CONTENT {recieved_msg}")
55
+ model_call(recieved_msg)
56
+
57
+ # (put your logic here)
58
+ print(f"CONSUMER {consumer_id}:::: Processing done")
59
+
60
+ ch.basic_ack(delivery_tag=method.delivery_tag)
61
+
62
+ def start_consumer():
63
+ # when scalled each server has consumer
64
+ params = pika.URLParameters(RABBITMQ_URL)
65
+ connection = pika.BlockingConnection(params)
66
+ channel = connection.channel()
67
+
68
+ channel.queue_declare(queue=QUEUE_NAME, durable=True)
69
+
70
+ channel.basic_qos(prefetch_count=1)
71
+
72
+ channel.basic_consume(
73
+ queue=QUEUE_NAME,
74
+ on_message_callback=callback
75
+ )
76
+
77
+ print("Waiting for messages...")
78
+ channel.start_consuming()
79
+
80
+ if __name__ == "__main__":
81
+ print(f"Starting New Consumer {consumer_id}...")
82
+ start_consumer()