from huggingface_hub import InferenceClient
import gradio as gr
import logging
import sys

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    client = InferenceClient(
        "mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
except Exception as e:
    logger.error(f"Failed to initialize Hugging Face client: {str(e)}")
    sys.exit(1)

def format_prompt(message, history):
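    # Mixtral-Instruct expects every turn wrapped in [INST] ... [/INST] tags,
    # with earlier assistant replies terminated by </s>.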
    try:
        prompt = "<s>"
        if history:
            for user_prompt, bot_response in history:
                prompt += f"[INST] {user_prompt} [/INST]"
                prompt += f" {bot_response}</s> "
        prompt += f"[INST] {message} [/INST]"
        logger.info(f"Formatted prompt: {prompt}")
        return prompt
    except Exception as e:
        logger.error(f"Error in format_prompt: {str(e)}")
        return None

def generate(message, chat_history, system_prompt, temperature=0.9, max_new_tokens=512, top_p=0.95):
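    # Stream a completion from the Inference API, yielding the accumulated text
    # so the Gradio chatbot can render the reply as it is generated.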
    try:
        logger.info(f"Received message: {message}")
        logger.info(f"System prompt: {system_prompt}")
        
        # Include the system prompt only on the first turn of a conversation
        if not chat_history:
            full_message = f"{system_prompt}\n\nUser: {message}"
        else:
            full_message = message
            
        formatted_prompt = format_prompt(full_message, chat_history)
        
        if not formatted_prompt:
            # yield (not return) so the message reaches the caller of this generator
            yield "I encountered an error formatting your message. Please try again."
            return
            
        # Generation parameters
        generate_kwargs = dict(
            temperature=float(temperature),
            max_new_tokens=int(max_new_tokens),
            top_p=float(top_p),
            do_sample=True,
            seed=42,
        )
        
        logger.info("Starting generation with parameters: %s", generate_kwargs)
        
        # Generate response
        response_stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False
        )
        
        partial_message = ""
        for response in response_stream:
            # Skip special tokens such as the </s> end-of-sequence marker so
            # they never show up in the chat window.
            if not response.token.special:
                partial_message += response.token.text
                yield partial_message

    except Exception as e:
        logger.error(f"Error in generate function: {str(e)}")
        yield f"I encountered an error: {str(e)}"

# Define the default system prompt
DEFAULT_SYSTEM_PROMPT = """You are a supportive AI assistant trained to provide emotional support and general guidance. 
Remember to: 
1. Show empathy and understanding
2. Ask clarifying questions when needed
3. Provide practical coping strategies
4. Encourage professional help when appropriate
5. Maintain boundaries and ethical guidelines"""

# Define the interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Message", placeholder="Type your message here...")
    
    with gr.Accordion("Advanced Options", open=False):
        system_prompt = gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Prompt",
            lines=3
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.1,
            label="Temperature"
        )
        max_new_tokens = gr.Slider(
            minimum=64,
            maximum=1024,
            value=512,
            step=64,
            label="Max Tokens"
        )

    clear = gr.Button("Clear")

    def user(user_message, history):
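        # Append the new user turn (bot reply filled in later) and clear the textbox.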
        return "", history + [[user_message, None]]

    def bot(history, system_prompt, temperature, max_new_tokens):
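        # Stream the model's reply into the last history entry so the chat
        # window updates incrementally.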
        if not history:
            # A bare "return history" inside a generator never reaches Gradio;
            # yield the unchanged history instead, then stop.
            yield history
            return
            
        user_message = history[-1][0]
        history[-1][1] = ""
        
        for chunk in generate(
            user_message, 
            history[:-1], 
            system_prompt, 
            temperature, 
            max_new_tokens
        ):
            history[-1][1] = chunk
            yield history

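    # Two-step event chain: first record the user turn without queueing, then
    # stream the bot response into the chat history.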
    msg.submit(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot,
        [chatbot, system_prompt, temperature, max_new_tokens],
        chatbot
    )

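    # Reset the chatbot to an empty conversation.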
    clear.click(lambda: None, None, chatbot, queue=False)

    gr.Markdown("""
    # PsyAssist - Advancing Mental Health Support with AI-Driven Interaction
    
    **Important Notice**: This is an AI-powered mental health support chatbot. While it can provide emotional support 
    and general guidance, it is not a replacement for professional mental health services. In case of emergency, 
    please contact your local mental health crisis hotline.
    """)

if __name__ == "__main__":
    try:
        demo.queue().launch(show_api=False)
    except Exception as e:
        logger.error(f"Failed to launch Gradio interface: {str(e)}")
        sys.exit(1)