Update app.py
Browse files
app.py
CHANGED
|
@@ -5,17 +5,18 @@ from dotenv import load_dotenv
|
|
| 5 |
from xhtml2pdf import pisa
|
| 6 |
import io
|
| 7 |
from textwrap import dedent
|
| 8 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 9 |
|
| 10 |
-
#
|
| 11 |
def load_resources():
|
| 12 |
load_dotenv()
|
| 13 |
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
|
| 14 |
subprocess.run(["huggingface-cli", "login", "--token", huggingface_token], capture_output=True)
|
| 15 |
-
tokenizer = AutoTokenizer.from_pretrained("istiak101/TinyLlama-1.1B-Chat-v0.6-rag-
|
| 16 |
-
model = AutoModelForCausalLM.from_pretrained("istiak101/TinyLlama-1.1B-Chat-v0.6-rag-
|
| 17 |
return model, tokenizer
|
| 18 |
|
|
|
|
| 19 |
def create_test_prompt(question, context, tokenizer):
|
| 20 |
prompt = dedent(
|
| 21 |
f"""
|
|
@@ -39,7 +40,7 @@ def create_test_prompt(question, context, tokenizer):
|
|
| 39 |
messages, tokenize=False, add_generation_prompt=True
|
| 40 |
)
|
| 41 |
|
| 42 |
-
#
|
| 43 |
if "llama_model" not in st.session_state or "llama_tokenizer" not in st.session_state:
|
| 44 |
model, tokenizer = load_resources()
|
| 45 |
st.session_state.llama_model = model
|
|
@@ -64,7 +65,7 @@ def get_llama_response(query):
|
|
| 64 |
outputs = pipe(query)
|
| 65 |
return outputs[0]["generated_text"]
|
| 66 |
|
| 67 |
-
#
|
| 68 |
def generate_pdf(convo, topic):
|
| 69 |
html = f"<h2>{topic}</h2><hr>"
|
| 70 |
for msg in convo:
|
|
@@ -79,7 +80,7 @@ def generate_pdf(convo, topic):
|
|
| 79 |
return None
|
| 80 |
return result
|
| 81 |
|
| 82 |
-
#
|
| 83 |
if "chat_sessions" not in st.session_state:
|
| 84 |
st.session_state.chat_sessions = {}
|
| 85 |
if "current_conversation" not in st.session_state:
|
|
@@ -87,10 +88,10 @@ if "current_conversation" not in st.session_state:
|
|
| 87 |
if "edit_mode" not in st.session_state:
|
| 88 |
st.session_state.edit_mode = {}
|
| 89 |
|
| 90 |
-
#
|
| 91 |
st.title("💬 Tell Me Why")
|
| 92 |
|
| 93 |
-
#
|
| 94 |
st.markdown("""
|
| 95 |
<style>
|
| 96 |
.chat-wrapper {
|
|
@@ -150,7 +151,7 @@ st.markdown("""
|
|
| 150 |
</style>
|
| 151 |
""", unsafe_allow_html=True)
|
| 152 |
|
| 153 |
-
#
|
| 154 |
st.sidebar.title("Conversations")
|
| 155 |
titles = list(st.session_state.chat_sessions.keys())
|
| 156 |
|
|
@@ -168,7 +169,7 @@ if titles:
|
|
| 168 |
else:
|
| 169 |
st.sidebar.write("No conversations yet. Start one below!")
|
| 170 |
|
| 171 |
-
#
|
| 172 |
with st.sidebar.form(key='new_conversation_form', clear_on_submit=True):
|
| 173 |
new_topic = st.text_input("New Conversation Name")
|
| 174 |
submit_button = st.form_submit_button("Start New Conversation")
|
|
@@ -184,7 +185,7 @@ with st.sidebar.form(key='new_conversation_form', clear_on_submit=True):
|
|
| 184 |
else:
|
| 185 |
st.sidebar.warning("Conversation already exists!")
|
| 186 |
|
| 187 |
-
#
|
| 188 |
if st.session_state.current_conversation:
|
| 189 |
convo = st.session_state.chat_sessions[st.session_state.current_conversation]
|
| 190 |
|
|
@@ -249,8 +250,7 @@ if st.session_state.current_conversation:
|
|
| 249 |
|
| 250 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 251 |
|
| 252 |
-
#
|
| 253 |
-
# --- Generate Assistant Response ---
|
| 254 |
if len(convo) % 2 == 1:
|
| 255 |
last_user_msg = convo[-1]["text"]
|
| 256 |
with st.spinner("Generating response..."):
|
|
|
|
| 5 |
from xhtml2pdf import pisa
|
| 6 |
import io
|
| 7 |
from textwrap import dedent
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 9 |
|
| 10 |
+
# Load Model Resources
|
| 11 |
def load_resources():
|
| 12 |
load_dotenv()
|
| 13 |
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
|
| 14 |
subprocess.run(["huggingface-cli", "login", "--token", huggingface_token], capture_output=True)
|
| 15 |
+
tokenizer = AutoTokenizer.from_pretrained("istiak101/TinyLlama-1.1B-Chat-v0.6-rag-finetunedv2.0")
|
| 16 |
+
model = AutoModelForCausalLM.from_pretrained("istiak101/TinyLlama-1.1B-Chat-v0.6-rag-finetunedv2.0")
|
| 17 |
return model, tokenizer
|
| 18 |
|
| 19 |
+
# Chat Prompt
|
| 20 |
def create_test_prompt(question, context, tokenizer):
|
| 21 |
prompt = dedent(
|
| 22 |
f"""
|
|
|
|
| 40 |
messages, tokenize=False, add_generation_prompt=True
|
| 41 |
)
|
| 42 |
|
| 43 |
+
# Store model and tokenizer in session state
|
| 44 |
if "llama_model" not in st.session_state or "llama_tokenizer" not in st.session_state:
|
| 45 |
model, tokenizer = load_resources()
|
| 46 |
st.session_state.llama_model = model
|
|
|
|
| 65 |
outputs = pipe(query)
|
| 66 |
return outputs[0]["generated_text"]
|
| 67 |
|
| 68 |
+
# PDF Generation
|
| 69 |
def generate_pdf(convo, topic):
|
| 70 |
html = f"<h2>{topic}</h2><hr>"
|
| 71 |
for msg in convo:
|
|
|
|
| 80 |
return None
|
| 81 |
return result
|
| 82 |
|
| 83 |
+
# Session Init
|
| 84 |
if "chat_sessions" not in st.session_state:
|
| 85 |
st.session_state.chat_sessions = {}
|
| 86 |
if "current_conversation" not in st.session_state:
|
|
|
|
| 88 |
if "edit_mode" not in st.session_state:
|
| 89 |
st.session_state.edit_mode = {}
|
| 90 |
|
| 91 |
+
# App Title
|
| 92 |
st.title("💬 Tell Me Why")
|
| 93 |
|
| 94 |
+
# Custom CSS
|
| 95 |
st.markdown("""
|
| 96 |
<style>
|
| 97 |
.chat-wrapper {
|
|
|
|
| 151 |
</style>
|
| 152 |
""", unsafe_allow_html=True)
|
| 153 |
|
| 154 |
+
# Sidebar: Conversations
|
| 155 |
st.sidebar.title("Conversations")
|
| 156 |
titles = list(st.session_state.chat_sessions.keys())
|
| 157 |
|
|
|
|
| 169 |
else:
|
| 170 |
st.sidebar.write("No conversations yet. Start one below!")
|
| 171 |
|
| 172 |
+
# New Conversation
|
| 173 |
with st.sidebar.form(key='new_conversation_form', clear_on_submit=True):
|
| 174 |
new_topic = st.text_input("New Conversation Name")
|
| 175 |
submit_button = st.form_submit_button("Start New Conversation")
|
|
|
|
| 185 |
else:
|
| 186 |
st.sidebar.warning("Conversation already exists!")
|
| 187 |
|
| 188 |
+
# Main Chat Area
|
| 189 |
if st.session_state.current_conversation:
|
| 190 |
convo = st.session_state.chat_sessions[st.session_state.current_conversation]
|
| 191 |
|
|
|
|
| 250 |
|
| 251 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 252 |
|
| 253 |
+
# User Prompt
|
|
|
|
| 254 |
if len(convo) % 2 == 1:
|
| 255 |
last_user_msg = convo[-1]["text"]
|
| 256 |
with st.spinner("Generating response..."):
|