File size: 3,804 Bytes
bb89f50
 
 
 
 
 
 
 
33fd638
3ce4cf9
bb89f50
 
3ce4cf9
1dec06e
3ce4cf9
 
 
 
bb89f50
 
 
 
 
3ce4cf9
 
 
bb89f50
 
03615c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb89f50
75dd79a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03615c4
bb89f50
03615c4
 
 
 
 
 
 
 
 
89bc1e6
 
33fd638
03615c4
 
 
33fd638
bb89f50
 
 
 
 
 
03615c4
bb89f50
 
 
 
03615c4
 
bb89f50
03615c4
 
bb89f50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03615c4
 
 
 
bb89f50
 
669e816
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Consumer
import time
import pika
import os
from Server import get_response
import json
from agent.agent_graph.StateTasks import ProblemState
import argparse
import redis
from encryption_utils import decrypt_token_from_json 


##################################################
# VARIABLES 
##################################################
# Command-line options for this consumer process (only --id, used in log lines).
argparse_model = argparse.ArgumentParser()
argparse_model.add_argument("--id", type=int, default=0, help="Consumer ID")
consumer_id = vars(argparse_model.parse_args())["id"]

# Required connection settings, read from the environment at import time.
# Lookups happen in this order; the first missing variable raises KeyError.
RABBITMQ_URL, QUEUE_NAME = os.environ["RABBITMQ_URL"], os.environ["QUEUE_NAME"]
redis_host, redis_port, redis_password = (
    os.environ["REDIS_HOST"],
    os.environ["REDIS_PORT"],  # kept as the raw string from the environment
    os.environ["REDIS_PASSWORD"],
)


##################################################
# PROCESSING METHODS
##################################################

def redis_send(user_id,msg_id,answer):
    """Store ``answer`` in Redis as JSON under a per-user, per-message key.

    The key has the form ``ANSWER_FOR_USER_ID<user_id>_OF_<msg_id>``.

    Returns:
        The value returned by the Redis ``set`` call.
    """
    client = redis.Redis(
        host=redis_host,
        port=redis_port,
        decode_responses=True,
        username="default",
        password=redis_password,
    )
    answer_key = f'ANSWER_FOR_USER_ID{user_id}_OF_{msg_id}'
    serialized_answer = json.dumps(answer)
    return client.set(answer_key, serialized_answer)


def model_call(request, token):
    """Run the agent on one queued request and return its answer.

    Args:
        request: The queued message as a dict, or its JSON-encoded string.
            Keys read here (all optional): ``prompt``, ``memory``,
            ``last_state``, ``user_email`` and ``user_name``.
        token: Decrypted access token, forwarded unchanged to ``get_response``.

    Returns:
        The value returned by ``get_response``, with its ``llm`` and
        ``rag_model`` entries (when present) replaced by ``""`` so the result
        can be JSON-encoded.
    """
    # Make sure request is a dict (it may arrive as a JSON string).
    if isinstance(request, str):
        request = json.loads(request)

    # Restore the previous agent state. ``last_state`` may be a JSON string or
    # an already-decoded dict. Anything unusable (malformed JSON, a non-dict
    # value, an unsupported type) is treated as "no state", so the request is
    # still answered from scratch rather than failing.
    raw_state = request.get('last_state')
    if isinstance(raw_state, dict):
        state = raw_state
    elif raw_state:
        try:
            state = json.loads(raw_state)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            state = {}
        if not isinstance(state, dict):
            state = {}
    else:
        state = {}

    # Fallback when there is no previous state: seed it from this request.
    if not state:
        state = {
            "question": request.get('prompt', ""),
            "memory": request.get('memory', []),
        }

    answer = get_response(
        request.get('prompt', ""),
        request.get('memory', []),
        token,
        state,
        request.get('user_email', ""),
        request.get('user_name', ""),
    )

    # Blank out entries holding live objects that json.dumps cannot encode.
    for k in ("llm", "rag_model"):
        if k in answer:
            answer[k] = ""

    return answer

def process_message(recieved_msg):
    """Answer one decoded queue message and publish the answer to Redis.

    Args:
        recieved_msg: Decoded message dict. Keys read here:
            ``ht_token_encrypted_dumped`` (a JSON string holding the encrypted
            token), ``user_id`` and ``msg_id``. The whole dict is also
            forwarded to ``model_call``.
    """
    # The encrypted token travels as a JSON string; decode it before decrypting.
    token = decrypt_token_from_json(json.loads(recieved_msg['ht_token_encrypted_dumped']))
    model_answer = model_call(recieved_msg, token)

    user_id = recieved_msg["user_id"]
    msg_id = recieved_msg["msg_id"]
    redis_send_res = redis_send(user_id, msg_id, model_answer)

    # Monitoring line: status and consumer id only, never message content.
    # (Previously ``{consumer_id}`` built a set literal instead of logging the id.)
    print({"STATUS": redis_send_res, "CONSUMER": consumer_id})


##################################################
# CONSUMER METHODS
##################################################


def get_connection():
    """Open and return a blocking RabbitMQ connection to ``RABBITMQ_URL``."""
    return pika.BlockingConnection(pika.URLParameters(RABBITMQ_URL))

def callback(ch, method, properties, body):
    """Handle one RabbitMQ delivery: decode it, process it, then ack or reject it.

    ``properties`` is unused here; it is part of pika's consumer callback
    signature.

    On success the delivery is acked. If decoding or processing raises, the
    delivery is rejected with ``requeue=False`` instead of letting the
    exception escape ``start_consuming`` and stop the consumer. Otherwise a
    message that fails every time would be redelivered indefinitely. A
    rejected message is discarded, or dead-lettered if the queue has a
    dead-letter exchange configured; that configuration is not visible here.
    """
    try:
        ##### Receive message and process it
        recieved_msg = json.loads(body.decode())
        print("-------------------------------------------------")
        print(f"MSG AT CONSUMER {consumer_id}" )

        ##### Process Message
        process_message(recieved_msg)
    except Exception as exc:
        # Top-level boundary: log only the error type, because the exception
        # text may contain user data from the message.
        print({"STATUS": "FAILED", "CONSUMER": consumer_id, "ERROR": type(exc).__name__})
        ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
    else:
        ###### Finalize
        ch.basic_ack(delivery_tag=method.delivery_tag)

def start_consumer():
    """Consume messages from ``QUEUE_NAME`` until interrupted.

    When the service is scaled out, each server runs its own consumer. The
    queue is declared durable, and ``prefetch_count=1`` limits this consumer
    to one unacknowledged delivery at a time. On Ctrl+C, consumption is
    stopped and the connection is closed instead of being left open.
    """
    connection = get_connection()
    try:
        channel = connection.channel()
        channel.queue_declare(queue=QUEUE_NAME, durable=True)
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue=QUEUE_NAME,
            on_message_callback=callback,
        )

        print("Waiting for messages...")
        try:
            channel.start_consuming()
        except KeyboardInterrupt:
            channel.stop_consuming()
    finally:
        # The broker may already have dropped the connection (e.g. after an
        # AMQP error), in which case there is nothing to close.
        if connection.is_open:
            connection.close()

##################################################
# MAIN
##################################################

if __name__ == "__main__":
    # Script entry point: announce this consumer's id, then block consuming.
    startup_banner = f"Starting New Consumer {consumer_id}..."
    print(startup_banner)
    start_consumer()