sifujohn commited on
Commit
d9983c0
·
verified ·
1 Parent(s): 8881491

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -3
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import streamlit as st
 
2
  from huggingface_hub import InferenceClient
3
 
4
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
@@ -7,9 +8,30 @@ def format_prompt(agent_id, message):
7
  return f"<s><Agent {agent_id}>[INST] {message} [/INST]"
8
 
9
  def generate_response(agent_id, message):
10
- formatted_prompt = format_prompt(agent_id, message)
11
- response = client.text_generation(formatted_prompt, stream=False)
12
- return response['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def generate_group_therapy_responses(message):
15
  responses = []
 
1
  import streamlit as st
2
+ import os
3
  from huggingface_hub import InferenceClient
4
 
5
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
8
  return f"<s><Agent {agent_id}>[INST] {message} [/INST]"
9
 
10
  def generate_response(agent_id, message):
11
+ api_url = "https://api-inference.huggingface.co/models/mistralai/mistral-tiny" # Replace with your model path
12
+ headers = {
13
+ 'Authorization': f'Bearer {os.getenv("HUGGINGFACE_API_KEY")}'
14
+ }
15
+ data = {
16
+ "inputs": {
17
+ "past_user_inputs": [],
18
+ "generated_responses": [],
19
+ "text": message
20
+ }
21
+ }
22
+
23
+ response = requests.post(api_url, headers=headers, json=data)
24
+ if response.status_code == 200:
25
+ response_data = response.json() # Parse the JSON response into a dictionary
26
+ # Ensure that 'generated_text' is accessed from a dictionary
27
+ if isinstance(response_data, dict) and 'generated_text' in response_data:
28
+ return response_data['generated_text']
29
+ else:
30
+ # Handle unexpected response format
31
+ return "Received an unexpected response format from the API."
32
+ else:
33
+ return f"Error: {response.status_code}"
34
+
35
 
36
  def generate_group_therapy_responses(message):
37
  responses = []