omar47 commited on
Commit
7f5fade
·
1 Parent(s): 74211b8

Update space

Browse files
Files changed (2) hide show
  1. app.py +238 -45
  2. requirements.txt +230 -1
app.py CHANGED
@@ -1,63 +1,256 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
 
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- response = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- response += token
40
- yield response
 
 
 
 
 
 
 
41
 
42
 
43
  """
44
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
  """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
 
62
 
63
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import os
3
+ import string
4
+ from pymongo import MongoClient
5
+ from openai import AsyncOpenAI, OpenAI
6
+ import copy
7
+ from constants import *
8
+ import asyncio
9
+ import string as st
10
+ from opik.integrations.openai import track_openai
11
+ from opik import track
12
+ from bson.objectid import ObjectId
13
+ import opik
14
 
 
 
 
 
15
 
16
# --- Module-level clients (created once at import time) ---
# Opik client for experiment/trace logging.
oClient = opik.Opik()
# MongoDB connection; MongoClient(None) falls back to localhost:27017,
# so an unset MONGO_URI still yields a local connection.
mdb = MongoClient(
    os.getenv("MONGO_URI")
)  # , "mongodb://localhost:27017/")) # Default to localhost if not set
# OpenAI clients wrapped with Opik tracing: async for chat/embeddings,
# sync client currently unused below but kept available.
aclient = track_openai(AsyncOpenAI())
client = track_openai(OpenAI())
# Application database handle.
db = mdb["Mindware"]
23
 
 
 
 
 
 
 
 
 
 
24
 
25
def purge(d):
    """
    Recursively collect all leaf nodes.

    Flattens a nested dict into a single-level dict: nested dicts are
    merged, lists of dicts are exploded into per-key lists, and plain
    values are copied through under their original key.
    """
    result = {}
    for k, v in d.items():
        if k == "chat_history":
            # NOTE(review): dead branch — `pass` means chat_history is NOT
            # actually skipped; it falls through to the checks below.
            pass
        if isinstance(v, dict):
            # Merge the flattened sub-dict into the result.
            result.update(purge(v))
        elif isinstance(v, list):
            # NOTE(review): loop variable `d` shadows the parameter `d`;
            # harmless here (the outer items() iterator is already bound)
            # but confusing. `idx` is unused.
            for idx, d in enumerate(v):

                if isinstance(d, dict):
                    try:
                        # Collect each inner key's values into a list.
                        for k1 in d.keys():
                            result[k1] = []
                        for k2, v2 in d.items():
                            result[k2].append(v2)
                    except Exception as e:
                        print("Error! Error!", e)
                    # Also accumulate the whole element under the list's key.
                    if k not in result.keys():
                        result[k] = []
                    result[k].append(d)
                else:
                    # Non-dict list element: keep the whole list.
                    result[k] = v

            else:
                # NOTE(review): for-else runs on every non-broken loop, so
                # this always overwrites result[k] with the raw list after
                # the per-element work above — confirm this is intended.
                result[k] = v
        else:
            # Plain leaf value.
            result[k] = v
    return result
57
+
58
+
59
def deploy(d):
    """
    Recursively deploy all leaf nodes.

    Thin wrapper around purge(): returns a fresh single-level dict of
    the flattened leaf values of *d*.
    """
    return dict(purge(d))
67
 
68
+
69
async def chat(prompt, model="gpt-4"):
    """Run a single-turn, non-streaming chat completion and return the reply text."""
    messages = [{"role": "user", "content": prompt}]
    completion = await aclient.chat.completions.create(model=model, messages=messages)
    return completion.choices[0].message.content
74
+
75
+
76
async def chat_generator(prompt: str, user: dict = None):
    """
    Stream a GPT-4 chat completion chunk by chunk.

    Args:
        prompt: Prompt template. If ``user`` is given, it is formatted
            with the flattened user fields (``str.format(**deploy(user))``).
        user: Optional user document; flattened via ``deploy`` before
            substitution into the prompt.

    Yields:
        Each non-empty text delta from the streaming response.
    """
    # Resolve the final prompt text once, outside the request call.
    content = prompt.format(**deploy(user)) if user else prompt
    response = await aclient.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": content}],
        stream=True,
    )
    async for chunk in response:
        # Skip keep-alive / empty deltas so callers only ever see real text.
        if chunk and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
93
+
94
+
95
def get_or_init_user(reply, userId):
    """
    Fetch the stored user document for ``userId``, or build a fresh one.

    Args:
        reply: The user's latest message; stored under ``user_query``.
        userId: Identifier looked up in the ``users`` collection.

    Returns:
        A user dict (without Mongo's ``_id`` field) with ``user_query``
        set to ``reply``. New users start from USER_TEMPLATE.
    """
    doc = db["users"].find_one({"userId": userId})
    if doc is None:
        # First contact: seed from the template (deep copy so the shared
        # template is never mutated).
        user = dict(copy.deepcopy(USER_TEMPLATE))
        user["userId"] = userId
        user["user_query"] = reply
        print("user created:", user)
    else:
        # Existing user: drop the Mongo _id so the doc can be re-inserted
        # later without a duplicate-key error. (The original popped "_id"
        # and then redundantly re-wrapped the same mutated dict.)
        user = dict(doc)
        user.pop("_id", None)
        user["user_query"] = reply

    return user
113
+
114
+
115
async def search(query, n=5):
    """
    Vector-search the ``runway`` collection for docs similar to *query*.

    Args:
        query: Free-text query; embedded with text-embedding-3-small.
        n: Maximum number of documents to return.

    Returns:
        A list of matching documents with the ``embedding`` field stripped.
    """
    embed = await aclient.embeddings.create(input=query, model="text-embedding-3-small")

    query_embedding = embed.data[0].embedding
    # MongoDB Atlas $vectorSearch stage against the "arrestor" index,
    # restricted to docs tagged class=THERAPIST.
    pipeline = [
        {
            "$vectorSearch": {
                "queryVector": query_embedding,
                "path": "embedding",
                "index": "arrestor",
                "score": {"$meta": "vectorSearchScore"},
                "filter": {"class": "THERAPIST"},
                "numCandidates": 850,
                "limit": n,
            }
        }
    ]

    # Drop the (large) embedding vectors from the returned docs.
    projection = [{"$project": {"embedding": 0}}]

    pipeline += projection

    docs = db["runway"].aggregate(pipeline)
    return list(docs)
139
+
140
+
141
async def agentic_search(query, n=5):
    """
    Vector-search, then LLM-filter the results for relevance to *query*.

    Over-fetches ``n * 5`` candidates, asks the model (ARAG_PROMPT)
    whether each is useful context, and keeps up to ``n`` documents the
    model answered "True" for.

    Args:
        query: Free-text search query.
        n: Maximum number of filtered documents to return.

    Returns:
        A list of at most ``n`` relevant documents.
    """
    candidates = await search(query, n * 5)
    tasks = [chat(ARAG_PROMPT.format(query=query, doc=doc)) for doc in candidates]
    # BUG FIX: the original did `is_context = asyncio.gather(*tasks)` without
    # awaiting it, so `is_context` was a Future (not the list of replies) and
    # the zip below never paired docs with their verdicts.
    verdicts = await asyncio.gather(*tasks)

    docs = []
    for doc, reply in zip(candidates, verdicts):
        if reply == "True":
            docs.append(doc)
        if len(docs) >= n:
            break

    return docs
154
+
155
+
156
async def update_user(response, user):
    """
    Persist the assistant's reply into the user document.

    Appends the latest user/assistant turn to ``chat_history``, records
    the reply as ``last_question``, and rewrites the Mongo document.
    """
    user["last_question"] = response
    user["chat_history"].append({"role": "user", "content": user["user_query"]})
    user["chat_history"].append({"role": "assistant", "content": response})
    # Delete-then-insert upsert: removes any stale copies for this userId
    # before writing the updated document back.
    db["users"].delete_many({"userId": user["userId"]})
    db["users"].insert_one(user)

    print("Updated", user["userId"], user["user_query"], "reply:", response)
164
+
165
+
166
async def add_background_tasks(task):
    """Dummy stand-in for FastAPI background tasks: awaits *task* inline."""
    await task
169
+
170
+
171
# Translation table stripping all ASCII punctuation except "_"; used below
# to normalize LLM-produced intent/risk labels before comparison.
punc_removal = str.maketrans("", "", string.punctuation.replace("_", ""))
172
+
173
+
174
async def escalate(user):
    """Placeholder escalation hook: log that a clinician should take over."""
    name = user.get("name")
    print(f"user {name} is not working, escalating to clinician")
176
+
177
+
178
async def update_docs(user):
    """Placeholder document-update hook: log the user's cached info."""
    name, cache = user["name"], user["cache"]
    print(f"updating docs for {name}: {cache}")
180
+
181
+
182
@track
async def handle_chat(reply, userId):
    """
    Handle the chat response and update the user.

    Runs intent/risk/cache/intensity classification in parallel, builds a
    prompt tailored to the detected intent, streams the model's reply word
    by word, and persists the updated user document afterwards.

    Args:
        reply: The user's latest message.
        userId: Identifier of the user in the ``users`` collection.

    Yields:
        Streamed text chunks of the assistant's reply.
    """
    user = get_or_init_user(reply, userId)
    prompt = BASE_PROMPT
    # Fan out the four classifier prompts concurrently; each one is
    # formatted with the flattened user fields.
    tasks = [
        chat(prompt=p.format(**deploy(user)))
        for p in [INTENT_PROMPT, RISK_PROMPT, CACHE_PROMPT, INTENSITY_PROMPT]
    ]
    responses = await asyncio.gather(*tasks)

    # responses[2] is the CACHE_PROMPT output; non-empty means there is
    # new info worth persisting.
    if responses[2]:
        user["cache"] = responses[2]
        await add_background_tasks(update_docs(user))

    # Normalize the intent label: uppercase, strip punctuation (except "_"),
    # and join words with underscores to match the branch constants below.
    intent = responses[0].upper().translate(punc_removal).replace(" ", "_")

    if intent == "ACTIVE_SPEAKING":
        prompt += SPEAKING_PROMPT
    elif intent == "VALIDATION_SEEK":
        prompt += VALIDATION_PROMPT
    elif intent == "OVERWHELMED":
        prompt += OVERWHELMED_PROMPT
        # Deliberate pause before replying to an overwhelmed user.
        await asyncio.sleep(5)
    elif intent == "REMOTE_REFERRAL":
        # Pull therapist docs relevant to the cached user info.
        results = await search(user["cache"], n=5)
        prompt += REMOTE_PROMPT.format(results=results, **deploy(user))
    elif intent == "NEUTRAL_STOP":
        prompt += STOP_PROMPT
    elif intent == "END_OF_NARRATIVE":
        prompt += END_PROMPT
    else:
        print("Unknown response of intent detection:", responses[0])

    # responses[1] is the RISK_PROMPT output, normalized the same way.
    if responses[1].upper().translate(punc_removal).replace(" ", "_") == "HIGH_RISK":
        prompt += HIGH_RISK_PROMPT
        await add_background_tasks(escalate(user))

    # Stream the final reply while accumulating it for persistence.
    response = ""
    async for word in chat_generator(prompt, user):
        if word:
            response += word
            yield word

    await add_background_tasks(update_user(response, user))

    return
231
+
232
 
233
async def respond(message, history, id):
    """
    Respond to the chat message and return the response.

    Streams chunks from handle_chat and yields the growing reply so the
    gradio ChatInterface renders incremental output.
    """
    chunks = []
    async for chunk in handle_chat(message, id):
        if chunk:
            chunks.append(chunk)
            yield "".join(chunks)
242
 
243
 
244
  """
245
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
246
  """
247
# Gradio UI: a chat interface with a per-session user id.
with gr.Blocks() as demo:
    # A freshly generated BSON ObjectId is the default user id for this
    # session; the textbox is editable so testers can impersonate a user.
    # NOTE(review): `id` shadows the builtin — consider renaming.
    id = gr.Textbox(str(ObjectId()), label="userID")
    gr.ChatInterface(
        fn=respond,
        type="messages",
        additional_inputs=[id],
    )
 
 
 
 
 
 
 
 
254
 
255
 
256
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1 +1,230 @@
1
- huggingface_hub==0.25.2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub==0.25.2
2
+ aiofiles==24.1.0
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.12.14
5
+ aiosignal==1.4.0
6
+ annotated-types==0.7.0
7
+ anyio==4.9.0
8
+ argon2-cffi==25.1.0
9
+ argon2-cffi-bindings==25.1.0
10
+ arrow==1.3.0
11
+ asgiref==3.8.1
12
+ asttokens==3.0.0
13
+ async-lru==2.0.5
14
+ attrs==25.3.0
15
+ Authlib==1.3.1
16
+ babel==2.17.0
17
+ beautifulsoup4==4.13.4
18
+ bleach==6.2.0
19
+ boto3-stubs==1.39.9
20
+ botocore-stubs==1.38.46
21
+ Brotli==1.1.0
22
+ certifi==2025.7.14
23
+ cffi==1.17.1
24
+ charset-normalizer==3.4.2
25
+ click==8.2.1
26
+ cobble==0.1.4
27
+ colorama==0.4.6
28
+ coloredlogs==15.0.1
29
+ comm==0.2.2
30
+ contourpy==1.3.2
31
+ cryptography==45.0.5
32
+ cycler==0.12.1
33
+ debugpy==1.8.15
34
+ decorator==5.2.1
35
+ defusedxml==0.7.1
36
+ dill==0.3.8
37
+ distro==1.9.0
38
+ dnspython==2.7.0
39
+ executing==2.2.0
40
+ fastapi==0.116.1
41
+ fastjsonschema==2.21.1
42
+ ffmpy==0.6.1
43
+ filelock==3.18.0
44
+ flake8==3.9.2
45
+ flatbuffers==25.2.10
46
+ fonttools==4.59.0
47
+ fqdn==1.5.1
48
+ frozenlist==1.7.0
49
+ fsspec==2025.7.0
50
+ gradio==5.41.0
51
+ gradio_client==1.11.0
52
+ groovy==0.1.2
53
+ h11==0.16.0
54
+ httpcore==1.0.9
55
+ httpx==0.27.0
56
+ huggingface-hub==0.34.3
57
+ humanfriendly==10.0
58
+ idna==3.10
59
+ imageio==2.37.0
60
+ imageio-ffmpeg==0.6.0
61
+ importlib_metadata==8.7.0
62
+ impyute==0.0.8
63
+ iniconfig==2.1.0
64
+ ipykernel==6.29.5
65
+ ipython==9.4.0
66
+ ipython_pygments_lexers==1.1.1
67
+ ipywidgets==8.1.7
68
+ isoduration==20.11.0
69
+ jedi==0.19.2
70
+ Jinja2==3.1.6
71
+ jiter==0.10.0
72
+ json5==0.12.0
73
+ jsonpointer==3.0.0
74
+ jsonschema==4.25.0
75
+ jsonschema-specifications==2025.4.1
76
+ jupyter==1.1.1
77
+ jupyter-console==6.6.3
78
+ jupyter-events==0.12.0
79
+ jupyter-lsp==2.2.6
80
+ jupyter_client==8.6.3
81
+ jupyter_core==5.8.1
82
+ jupyter_server==2.16.0
83
+ jupyter_server_terminals==0.5.3
84
+ jupyterlab==4.4.5
85
+ jupyterlab_pygments==0.3.0
86
+ jupyterlab_server==2.27.3
87
+ jupyterlab_widgets==3.0.15
88
+ kiwisolver==1.4.8
89
+ lark==1.2.2
90
+ litellm==1.74.4
91
+ lxml==6.0.0
92
+ magika==0.6.2
93
+ mammoth==1.9.1
94
+ markdown-it-py==3.0.0
95
+ markdownify==1.1.0
96
+ markitdown==0.1.2
97
+ MarkupSafe==3.0.2
98
+ matplotlib==3.10.3
99
+ matplotlib-inline==0.1.7
100
+ mccabe==0.6.1
101
+ mdurl==0.1.2
102
+ mistune==3.1.3
103
+ moviepy==2.2.1
104
+ mpmath==1.3.0
105
+ multidict==6.6.3
106
+ multiprocess==0.70.16
107
+ mypy-boto3-bedrock-runtime==1.39.7
108
+ nbclient==0.10.2
109
+ nbconvert==7.16.6
110
+ nbformat==5.10.4
111
+ nbqa==1.8.5
112
+ nest-asyncio==1.6.0
113
+ notebook==7.4.5
114
+ notebook_shim==0.2.4
115
+ numpy==2.3.1
116
+ onnxruntime==1.22.1
117
+ openai==1.97.0
118
+ opik==1.8.6
119
+ orjson==3.11.1
120
+ overrides==7.7.0
121
+ packaging==25.0
122
+ pandas==2.3.1
123
+ pandas-stubs==2.2.2.240603
124
+ pandocfilters==1.5.1
125
+ parso==0.8.4
126
+ pdfminer.six==20250506
127
+ pillow==11.3.0
128
+ platformdirs==4.3.8
129
+ pluggy==1.6.0
130
+ proglog==0.1.12
131
+ prometheus_client==0.22.1
132
+ prompt_toolkit==3.0.51
133
+ propcache==0.3.2
134
+ protobuf==6.31.1
135
+ psutil==7.0.0
136
+ pure_eval==0.2.3
137
+ pycodestyle==2.12.0
138
+ pycodestyle_magic==0.5
139
+ pycparser==2.22
140
+ pydantic==2.11.7
141
+ pydantic-settings==2.10.1
142
+ pydantic_core==2.33.2
143
+ pydub==0.25.1
144
+ pyflakes==2.3.1
145
+ Pygments==2.19.2
146
+ pymongo==4.13.2
147
+ pyparsing==3.2.3
148
+ pyreadline3==3.5.4
149
+ pytest==8.4.1
150
+ python-dateutil==2.9.0.post0
151
+ python-dotenv==1.1.1
152
+ python-json-logger==3.3.0
153
+ python-multipart==0.0.20
154
+ python-pptx==1.0.2
155
+ pytz==2025.2
156
+ pywin32==311
157
+ pywinpty==2.0.15
158
+ PyYAML==6.0.2
159
+ pyzmq==27.0.0
160
+ RapidFuzz==3.13.0
161
+ referencing==0.36.2
162
+ regex==2024.11.6
163
+ requests==2.32.4
164
+ rfc3339-validator==0.1.4
165
+ rfc3986-validator==0.1.1
166
+ rfc3987-syntax==1.1.0
167
+ rich==14.0.0
168
+ rpds-py==0.26.0
169
+ ruff==0.12.7
170
+ safehttpx==0.1.6
171
+ schwab-py==1.3.0
172
+ schwabdev==2.1.1
173
+ semantic-version==2.10.0
174
+ Send2Trash==1.8.3
175
+ sentry-sdk==2.33.0
176
+ shellingham==1.5.4
177
+ six==1.17.0
178
+ sniffio==1.3.1
179
+ soupsieve==2.7
180
+ sqlparse==0.5.0
181
+ stack-data==0.6.3
182
+ starlette==0.47.1
183
+ sympy==1.14.0
184
+ tenacity==9.1.2
185
+ terminado==0.18.1
186
+ tiktoken==0.9.0
187
+ tinycss2==1.4.0
188
+ tokenize-rt==5.2.0
189
+ tokenizers==0.21.2
190
+ tomli==2.0.1
191
+ tomlkit==0.13.3
192
+ tornado==6.5.1
193
+ tqdm==4.67.1
194
+ traitlets==5.14.3
195
+ typer==0.16.0
196
+ types-awscrt==0.27.4
197
+ types-bleach==6.1.0.20240331
198
+ types-colorama==0.4.15.20240311
199
+ types-croniter==2.0.0.20240423
200
+ types-decorator==5.1.8.20240310
201
+ types-docutils==0.21.0.20240423
202
+ types-html5lib==1.1.11.20240228
203
+ types-jsonschema==4.22.0.20240610
204
+ types-Markdown==3.6.0.20240316
205
+ types-Pillow==10.2.0.20240520
206
+ types-psutil==5.9.5.20240516
207
+ types-Pygments==2.18.0.20240506
208
+ types-python-dateutil==2.9.0.20240316
209
+ types-pytz==2024.1.0.20240417
210
+ types-PyYAML==6.0.12.20240311
211
+ types-requests==2.32.0.20240602
212
+ types-s3transfer==0.13.0
213
+ types-setuptools==70.0.0.20240524
214
+ types-tqdm==4.66.0.20240417
215
+ typing-inspection==0.4.1
216
+ typing_extensions==4.14.1
217
+ tzdata==2025.2
218
+ uri-template==1.3.0
219
+ urllib3==2.5.0
220
+ uuid6==2025.0.1
221
+ uvicorn==0.35.0
222
+ wcwidth==0.2.13
223
+ webcolors==24.11.1
224
+ webencodings==0.5.1
225
+ websocket-client==1.8.0
226
+ websockets==12.0
227
+ widgetsnbextension==4.0.14
228
+ xlsxwriter==3.2.5
229
+ yarl==1.20.1
230
+ zipp==3.23.0