import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from huggingface_hub import login
from peft import PeftModel, PeftConfig
import time

# Login with HF_TOKEN (if available)
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token, add_to_git_credential=False)
        st.success("Hugging Face login successful!")
    except Exception as e:
        st.error(f"Hugging Face login failed: {e}")
else:
    st.warning("HF_TOKEN environment variable not set. Some features may be limited.")

# Model and Adapter Configuration
model_id = "Prajjwalng/gemma_customer_care"  # Base model
adapter_id = "Prajjwalng/gemma_customercare_adapters"  # LoRA adapter repo (defined but not loaded below; see note after load_model)

# Initialize model and tokenizer (load only once)
@st.cache_resource
def load_model(model_id):
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map={"": 0} if torch.cuda.is_available() else "cpu"
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)
    return base_model, tokenizer

merged_model, tokenizer = load_model(model_id)
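
# Note: the peft imports and `adapter_id` above are unused as written; the
# checkpoint at `model_id` is assumed to already contain the fine-tuned weights
# (hence the name `merged_model`). If the LoRA weights instead live only in the
# adapter repo, a minimal sketch (assuming the adapter targets this base model) would be:
#
#     merged_model = PeftModel.from_pretrained(merged_model, adapter_id)
#     merged_model = merged_model.merge_and_unload()  # optionally bake the adapter into the base weights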

# Function to generate chatbot response using the provided template
def get_completion(query: str, model, tokenizer) -> str:
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Plain (non f-string) template; {query} is substituted via .format() below
    prompt_template = """
<start_of_turn>system You are a support chatbot who helps with user queries and always responds in the style of a professional.\n<end_of_turn>
<start_of_turn>user


{query}
<end_of_turn>

<start_of_turn>model
"""
    prompt = prompt_template.format(query=query)

    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

    model_inputs = encodeds.to(device)

    model.to(device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Keep only the text after the final "model" turn marker (the assistant's reply)
    model_response = decoded.split("model\n")[-1].strip()
    return model_response

# Streamlit app
st.title("Customer Care ChatBot")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
    # Add initial welcome message
    initial_message = {"role": "assistant", "content": "Hi, I am Sora, I am your customer support agent."}
    st.session_state.messages.append(initial_message)

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("How can I help you?"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate and display chatbot response
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        typing_placeholder = st.empty()
        typing_dots = ""  # Initialize empty string for typing dots

        # Animate typing dots
        for i in range(3):
            typing_dots += "."
            typing_placeholder.markdown(typing_dots)
            time.sleep(0.3)  # Adjust speed as needed

        typing_placeholder.empty()  # Clear typing dots

        full_response = ""
        response = get_completion(prompt, merged_model, tokenizer)

        # Simulate stream of responses with milliseconds delay
        for chunk in response.split():
            full_response += chunk + " "
            time.sleep(0.05)
            # Add a placeholder to stream the response
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": full_response})