ZBro7 committed on
Commit
35c109d
·
verified ·
1 Parent(s): cb0c171

Update router.py

Browse files
Files changed (1) hide show
  1. router.py +18 -26
router.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import asyncio
2
  import time
3
  import requests
@@ -19,9 +20,7 @@ from rag_engine import rag_response
19
  # =====================================
20
 
21
  IMAGE_SPACE_URL = "https://your-image-space.hf.space/generate"
22
-
23
  CACHE_TTL_SECONDS = 300 # 5 minutes
24
-
25
  response_cache = {}
26
 
27
 
@@ -54,7 +53,6 @@ def set_cache(cache_key, response):
54
  # =====================================
55
 
56
  def build_messages(system_prompt, memory, user_prompt):
57
-
58
  messages = []
59
 
60
  if system_prompt:
@@ -67,19 +65,19 @@ def build_messages(system_prompt, memory, user_prompt):
67
 
68
 
69
  # =====================================
70
- # IMAGE SERVICE
71
  # =====================================
72
 
73
- def call_image_microservice(prompt):
74
-
75
  try:
76
- response = requests.post(
77
- IMAGE_SPACE_URL,
78
- json={"prompt": prompt},
79
- timeout=60
 
 
80
  )
81
- return response.json()
82
- except:
83
  return {"error": "Image service unavailable"}
84
 
85
 
@@ -99,7 +97,7 @@ async def async_gemini(messages):
99
  # MAIN ROUTER
100
  # =====================================
101
 
102
- def route_request(prompt, user_id):
103
 
104
  cache_key = f"{user_id}:{prompt}"
105
 
@@ -115,7 +113,7 @@ def route_request(prompt, user_id):
115
  # ==========================
116
  if prompt.startswith("/image"):
117
  clean_prompt = prompt.replace("/image", "").strip()
118
- return call_image_microservice(clean_prompt)
119
 
120
  # ==========================
121
  # RAG QUICK RESPONSE
@@ -134,7 +132,6 @@ def route_request(prompt, user_id):
134
  # CLASSIFY
135
  # ==========================
136
  classification = classify_prompt(prompt)
137
-
138
  intent = classification.get("intent", "chat")
139
  needs_search = classification.get("needs_search", False)
140
 
@@ -158,7 +155,7 @@ def route_request(prompt, user_id):
158
  if intent == "reasoning":
159
 
160
  messages = build_messages(system_prompt, memory, prompt)
161
- response = call_gemini(messages)
162
 
163
  save_message(user_id, "user", prompt)
164
  save_message(user_id, "assistant", response)
@@ -185,16 +182,11 @@ Use web data if helpful.
185
 
186
  messages = build_messages(system_prompt, memory, enriched_prompt)
187
 
188
- async def run_parallel():
189
- llama_task = asyncio.create_task(async_llama(messages))
190
- gemini_task = asyncio.create_task(async_gemini(messages))
191
-
192
- llama_answer = await llama_task
193
- gemini_answer = await gemini_task
194
-
195
- return llama_answer, gemini_answer
196
 
197
- llama_answer, gemini_answer = asyncio.run(run_parallel())
 
198
 
199
  winner = judge_answers(llama_answer, gemini_answer)
200
  final_answer = gemini_answer if winner == 2 else llama_answer
@@ -211,7 +203,7 @@ Use web data if helpful.
211
  # ==========================
212
  messages = build_messages(system_prompt, memory, prompt)
213
 
214
- response = call_llama(messages)
215
 
216
  save_message(user_id, "user", prompt)
217
  save_message(user_id, "assistant", response)
 
1
+
2
  import asyncio
3
  import time
4
  import requests
 
20
  # =====================================
21
 
22
  IMAGE_SPACE_URL = "https://your-image-space.hf.space/generate"
 
23
  CACHE_TTL_SECONDS = 300 # 5 minutes
 
24
  response_cache = {}
25
 
26
 
 
53
  # =====================================
54
 
55
  def build_messages(system_prompt, memory, user_prompt):
 
56
  messages = []
57
 
58
  if system_prompt:
 
65
 
66
 
67
  # =====================================
68
+ # IMAGE SERVICE (Async Safe)
69
  # =====================================
70
 
71
async def call_image_microservice(prompt):
    """Call the external image-generation Space without blocking the event loop.

    The blocking ``requests.post`` call (and JSON parsing of its response)
    runs in a worker thread via ``asyncio.to_thread`` so the asyncio event
    loop stays responsive while the image service works.

    Args:
        prompt: Text prompt forwarded to the image service.

    Returns:
        dict: Parsed JSON response from the service on success, or
        ``{"error": "Image service unavailable"}`` when the request fails
        or the payload is not valid JSON.
    """
    try:
        return await asyncio.to_thread(
            lambda: requests.post(
                IMAGE_SPACE_URL,
                json={"prompt": prompt},
                timeout=60,  # hard cap; image generation can be slow
            ).json()
        )
    # Narrowed from a blanket `except Exception`: only transport failures
    # (RequestException covers timeouts, connection errors, etc.) and a
    # malformed JSON body (ValueError, the base of requests' JSONDecodeError)
    # count as "service unavailable". Programming errors now propagate
    # instead of being silently masked as a service outage.
    except (requests.RequestException, ValueError):
        return {"error": "Image service unavailable"}
82
 
83
 
 
97
  # MAIN ROUTER
98
  # =====================================
99
 
100
+ async def route_request(prompt, user_id):
101
 
102
  cache_key = f"{user_id}:{prompt}"
103
 
 
113
  # ==========================
114
  if prompt.startswith("/image"):
115
  clean_prompt = prompt.replace("/image", "").strip()
116
+ return await call_image_microservice(clean_prompt)
117
 
118
  # ==========================
119
  # RAG QUICK RESPONSE
 
132
  # CLASSIFY
133
  # ==========================
134
  classification = classify_prompt(prompt)
 
135
  intent = classification.get("intent", "chat")
136
  needs_search = classification.get("needs_search", False)
137
 
 
155
  if intent == "reasoning":
156
 
157
  messages = build_messages(system_prompt, memory, prompt)
158
+ response = await async_gemini(messages)
159
 
160
  save_message(user_id, "user", prompt)
161
  save_message(user_id, "assistant", response)
 
182
 
183
  messages = build_messages(system_prompt, memory, enriched_prompt)
184
 
185
+ llama_task = asyncio.create_task(async_llama(messages))
186
+ gemini_task = asyncio.create_task(async_gemini(messages))
 
 
 
 
 
 
187
 
188
+ llama_answer = await llama_task
189
+ gemini_answer = await gemini_task
190
 
191
  winner = judge_answers(llama_answer, gemini_answer)
192
  final_answer = gemini_answer if winner == 2 else llama_answer
 
203
  # ==========================
204
  messages = build_messages(system_prompt, memory, prompt)
205
 
206
+ response = await async_llama(messages)
207
 
208
  save_message(user_id, "user", prompt)
209
  save_message(user_id, "assistant", response)