|
|
import html
import io
import os
import subprocess
from textwrap import dedent

import streamlit as st
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from xhtml2pdf import pisa
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_resources(model_name="istiak101/TinyLlama-1.1B-Chat-v0.6-rag-finetunedv2.0"):
    """Load the fine-tuned chat model and its tokenizer.

    Loads a ``.env`` file, reads ``HUGGINGFACE_TOKEN`` and, when it is set,
    logs in through ``huggingface-cli`` before calling ``from_pretrained``.
    The result is cached with ``st.cache_resource`` so the weights are loaded
    once and shared instead of being reloaded for every browser session.
    ``show_spinner=False`` keeps the cache from rendering anything, so calling
    this before ``st.set_page_config`` stays legal.

    Args:
        model_name: Hugging Face repository id to load. Defaults to the app's
            fine-tuned TinyLlama checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    load_dotenv()
    huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

    # With no token, putting None into argv would raise TypeError; skip the
    # login instead (public repositories should still load anonymously).
    if huggingface_token:
        try:
            # Best effort, as before: output is captured and the exit status is
            # not checked.
            # NOTE(review): the token is visible in the process list while the
            # CLI runs; exporting it via the environment would avoid that.
            subprocess.run(
                ["huggingface-cli", "login", "--token", huggingface_token],
                capture_output=True,
                check=False,
            )
        except FileNotFoundError:
            # huggingface-cli is not installed / not on PATH: continue without
            # logging in rather than crashing the app.
            pass

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
|
|
|
|
|
|
|
|
def create_test_prompt(question, context, tokenizer):
    """Build the chat-formatted prompt for one question/context pair.

    The user turn is the question, a blank line, ``Information:`` and the
    context wrapped in a fenced block; a system turn tells the model to use
    only that information.

    Args:
        question: The user's question.
        context: Reference text the answer must be drawn from.
        tokenizer: Object exposing ``apply_chat_template`` (e.g. a Hugging
            Face tokenizer).

    Returns:
        The value of ``tokenizer.apply_chat_template(..., tokenize=False,
        add_generation_prompt=True)``.
    """
    # Dedent the template *before* inserting user text. Dedenting after
    # interpolation (f-string inside dedent) silently stops working when the
    # question or context has an unindented line, leaving every template line
    # indented in the prompt.
    template = dedent(
        """
        {question}

        Information:

        ```
        {context}
        ```
        """
    )
    # str.format only parses the template, so braces in user text are safe.
    prompt = template.format(question=question, context=context)

    messages = [
        {
            "role": "system",
            "content": "Use only the information to answer the question",
        },
        {"role": "user", "content": prompt},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
|
|
|
|
|
|
|
|
# set_page_config has to come before any other Streamlit output, so it runs
# first and the model load below can show a spinner.
# NOTE(review): the page icon was restored from a mis-encoded copy of this
# file; confirm it is the intended emoji.
st.set_page_config(page_title="Tell Me Why", page_icon="❓", layout="wide")

# Keep the model and tokenizer in session state so script reruns reuse them
# instead of calling load_resources() again.
if "llama_model" not in st.session_state or "llama_tokenizer" not in st.session_state:
    with st.spinner("Loading model..."):
        model, tokenizer = load_resources()
    st.session_state.llama_model = model
    st.session_state.llama_tokenizer = tokenizer
|
|
|
|
|
def get_llama_response(query, max_new_tokens=128):
    """Generate a completion for *query* with the session's model.

    Args:
        query: Fully formatted prompt text (e.g. from ``create_test_prompt``).
        max_new_tokens: Maximum number of tokens to generate (default 128).

    Returns:
        Only the newly generated text; the prompt is excluded because the
        pipeline is built with ``return_full_text=False``.
    """
    # Constructing a pipeline on every call repeats setup work; keep one per
    # session and rebuild it only if the session's model object changed.
    pipe = st.session_state.get("llama_pipeline")
    if pipe is None or pipe.model is not st.session_state.llama_model:
        pipe = pipeline(
            task="text-generation",
            model=st.session_state.llama_model,
            tokenizer=st.session_state.llama_tokenizer,
            return_full_text=False,
        )
        st.session_state.llama_pipeline = pipe

    # Generation length is passed per call so the cached pipeline can serve
    # different budgets.
    outputs = pipe(query, max_new_tokens=max_new_tokens)
    return outputs[0]["generated_text"]
|
|
|
|
|
|
|
|
def generate_pdf(convo, topic):
    """Render a conversation as a PDF document with xhtml2pdf.

    Args:
        convo: List of message dicts with ``"role"`` and ``"text"`` keys. Only
            ``"user"`` and ``"assistant"`` messages are included.
        topic: Conversation title, used as the document heading.

    Returns:
        An ``io.BytesIO`` containing the PDF, rewound to the start, or
        ``None`` if xhtml2pdf reported an error.
    """
    speakers = {"user": "You", "assistant": "AI Assistant"}

    def _as_html(text):
        # Escape user/model text so stray '<' or '&' cannot break the markup,
        # but keep the app's own "<br>" separators and turn real newlines into
        # line breaks (HTML would otherwise collapse them).
        escaped = html.escape(str(text))
        return escaped.replace("&lt;br&gt;", "<br>").replace("\n", "<br>")

    parts = [f"<h2>{html.escape(str(topic))}</h2><hr>"]
    for msg in convo:
        speaker = speakers.get(msg["role"])
        if speaker is not None:
            parts.append(f"<p><strong>{speaker}:</strong> {_as_html(msg['text'])}</p>")

    result = io.BytesIO()
    pisa_status = pisa.CreatePDF(io.StringIO("".join(parts)), dest=result)
    if pisa_status.err:
        return None
    # Rewind so consumers that call read() get the whole document.
    result.seek(0)
    return result
|
|
|
|
|
|
|
|
# First run of a session: seed the conversation store, the selected topic and
# the per-message edit flags. Existing values are left untouched on reruns.
for _state_key, _make_default in (
    ("chat_sessions", dict),
    ("current_conversation", lambda: None),
    ("edit_mode", dict),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
|
|
|
|
|
|
|
st.title("💬 Tell Me Why")

# Inject the CSS for the chat UI; .user-bubble, .assistant-bubble and
# .bubble-header are the classes the chat messages are rendered with.
st.markdown(
    """
<style>
.chat-wrapper {
    display: flex;
    flex-direction: column;
}

.message-row {
    display: flex;
    align-items: flex-end;
    justify-content: flex-end;
    margin-top: 10px;
}

.user-bubble {
    background-color: rgba(0, 200, 83, 0.1);
    color: var(--text-color);
    padding: 15px;
    border-radius: 15px;
    max-width: 80%;
    border: 1px solid rgba(0, 200, 83, 0.4);
    position: relative;
    margin-bottom: 20px;
}

.assistant-bubble {
    background-color: rgba(3, 169, 244, 0.1);
    color: var(--text-color);
    padding: 15px;
    border-radius: 15px;
    max-width: 80%;
    align-self: flex-start;
    margin-right: auto;
    border: 1px solid rgba(3, 169, 244, 0.3);
    margin-top: 10px;
    margin-bottom: 40px;
}

.bubble-header {
    font-size: 14px;
    font-weight: bold;
    margin-bottom: 5px;
    display: flex;
    align-items: center;
}

.bubble-header span {
    margin-left: 5px;
}

.icon-col {
    display: flex;
    flex-direction: column;
    gap: 5px;
    align-items: center;
}
</style>
""",
    unsafe_allow_html=True,
)
|
|
|
|
|
|
|
|
# Sidebar: one row per conversation with a select button and a delete button.
# NOTE(review): the button icons were restored from a mis-encoded copy of
# this file; confirm they are the intended emoji.
st.sidebar.title("Conversations")

titles = list(st.session_state.chat_sessions.keys())
if not titles:
    st.sidebar.write("No conversations yet. Start one below!")

# Iterate over a snapshot of the titles: deleting a conversation mutates
# chat_sessions, and st.rerun() then stops this run anyway.
for topic in titles:
    select_col, delete_col = st.sidebar.columns([0.8, 0.2])

    if select_col.button(f"🗨 {topic}", key=f"select_{topic}"):
        if st.session_state.current_conversation != topic:
            # Edit flags may be keyed by message position only; drop them so
            # an open editor does not reappear in the other conversation.
            st.session_state.edit_mode = {}
        st.session_state.current_conversation = topic
        st.rerun()

    if delete_col.button("🗑", key=f"delete_{topic}"):
        del st.session_state.chat_sessions[topic]
        if st.session_state.current_conversation == topic:
            st.session_state.current_conversation = None
            st.session_state.edit_mode = {}
        st.rerun()
|
|
|
|
|
|
|
|
# Sidebar form for creating a conversation and switching to it.
with st.sidebar.form(key='new_conversation_form', clear_on_submit=True):
    new_topic = st.text_input("New Conversation Name")
    submit_button = st.form_submit_button("Start New Conversation")

if submit_button:
    # Normalise the name once so " foo" and "foo" are the same conversation
    # for both the duplicate check and the stored key.
    topic_name = new_topic.strip()
    if not topic_name:
        st.sidebar.warning("Please enter a name.")
    elif topic_name in st.session_state.chat_sessions:
        st.sidebar.warning("Conversation already exists!")
    else:
        st.session_state.chat_sessions[topic_name] = []
        st.session_state.current_conversation = topic_name
        # Switching conversations: discard edit flags from the previous one.
        st.session_state.edit_mode = {}
        st.sidebar.success(f"Started new conversation: {topic_name}")
        st.rerun()
|
|
|
|
|
|
|
|
# Main chat area for the selected conversation.
# NOTE(review): the icon glyphs in this section were restored from a
# mis-encoded copy of this file; confirm they are the intended emoji
# (especially the assistant icon).

_QUESTION_LABEL = "Question: "
_CONTEXT_LABEL = "Context: "
_PART_SEPARATOR = "<br><br>"


def _split_user_message(text):
    """Split a stored user message into its ``(question, context)`` parts.

    User messages are stored as ``"Question: <q><br><br>Context: <c>"``. Only
    the first separator splits the message and each label is removed only as
    a prefix, so the user's own text may contain either string. A message
    without a separator yields an empty context instead of raising.
    """
    question, _, context = text.partition(_PART_SEPARATOR)
    if question.startswith(_QUESTION_LABEL):
        question = question[len(_QUESTION_LABEL):]
    if context.startswith(_CONTEXT_LABEL):
        context = context[len(_CONTEXT_LABEL):]
    return question, context


def _join_user_message(question, context):
    """Encode a question/context pair in the format _split_user_message reads."""
    return f"{_QUESTION_LABEL}{question}{_PART_SEPARATOR}{_CONTEXT_LABEL}{context}"


def _message_html(text):
    """Return *text* escaped for display inside an HTML chat bubble.

    The app's own ``<br>`` separators are kept; newlines become ``<br>`` so a
    blank line in the input cannot end the bubble's HTML block early.
    """
    escaped = html.escape(str(text))
    return escaped.replace("&lt;br&gt;", "<br>").replace("\n", "<br>")


def _render_bubble(css_class, icon, speaker, text):
    """Render one chat message as a single-line HTML bubble."""
    st.markdown(
        f'<div class="{css_class}">'
        f'<div class="bubble-header">{icon} <span>{speaker}</span></div>'
        f"{_message_html(text)}"
        "</div>",
        unsafe_allow_html=True,
    )


_current_topic = st.session_state.current_conversation
# convo is the list stored in session state, so in-place edits/appends
# update the conversation directly.
convo = st.session_state.chat_sessions.get(_current_topic) if _current_topic else None

if convo is not None:
    for idx, msg in enumerate(convo):
        # Edit flags are keyed per conversation so an editor opened here does
        # not open the message at the same position in another conversation.
        edit_key = (_current_topic, idx)
        with st.container():
            if msg["role"] == "user":
                if st.session_state.edit_mode.get(edit_key, False):
                    question_input, context_input = _split_user_message(msg["text"])
                    new_question = st.text_input(
                        "Edit your question:",
                        value=question_input,
                        key=f"edit_question_{_current_topic}_{idx}",
                    )
                    new_context = st.text_area(
                        "Edit your context:",
                        value=context_input,
                        key=f"edit_context_{_current_topic}_{idx}",
                    )
                    col1, col2 = st.columns([1, 1])
                    with col1:
                        if st.button("✅ Save", key=f"save_{idx}"):
                            # Keep the labelled format used for new messages so the
                            # edited message displays and re-parses the same way.
                            msg["text"] = _join_user_message(new_question, new_context)
                            has_reply = (
                                idx + 1 < len(convo)
                                and convo[idx + 1]["role"] == "assistant"
                            )
                            if has_reply:
                                prompt = create_test_prompt(
                                    new_question, new_context, st.session_state.llama_tokenizer
                                )
                                with st.spinner("Generating response..."):
                                    try:
                                        new_response = get_llama_response(prompt)
                                    except Exception:
                                        new_response = "Failed to retrieve response."
                                convo[idx + 1]["text"] = new_response
                            # With no reply yet, the message stays pending and the
                            # generation step after this loop answers it on rerun.
                            st.session_state.edit_mode[edit_key] = False
                            st.rerun()
                    with col2:
                        if st.button("❌ Cancel", key=f"cancel_{idx}"):
                            st.session_state.edit_mode[edit_key] = False
                            st.rerun()
                else:
                    col1, col2 = st.columns([0.1, 0.9])
                    with col1:
                        if st.button("✏️", key=f"edit_btn_{idx}"):
                            st.session_state.edit_mode[edit_key] = True
                            st.rerun()
                    with col2:
                        _render_bubble("user-bubble", "👤", "You", msg["text"])
            elif msg["role"] == "assistant":
                _render_bubble("assistant-bubble", "📚", "AI Assistant", msg["text"])

    # A trailing user message has no answer yet: generate it now, then rerun
    # so the reply is rendered in place with the rest of the conversation.
    if convo and convo[-1]["role"] == "user":
        question_input, context_input = _split_user_message(convo[-1]["text"])
        prompt = create_test_prompt(
            question_input, context_input, st.session_state.llama_tokenizer
        )
        with st.spinner("Generating response..."):
            try:
                assistant_reply = get_llama_response(prompt)
            except Exception:
                assistant_reply = "⚠️ Failed to generate response"
        convo.append({"role": "assistant", "text": assistant_reply})
        st.rerun()

    st.markdown("---")

    if st.button("📥 Export Conversation as PDF"):
        pdf_bytes = generate_pdf(convo, _current_topic)
        if pdf_bytes is not None:
            st.download_button(
                "Download PDF",
                pdf_bytes,
                file_name="TellMeWhy_Conversation.pdf",
                mime="application/pdf",
            )
        else:
            st.error("❌ Failed to generate PDF.")

    with st.form(key="submit_form", clear_on_submit=True):
        question_input = st.text_input("Enter your question:")
        context_input = st.text_area("Enter your context:")
        submit_button = st.form_submit_button("Submit")

    if submit_button:
        if question_input.strip() and context_input.strip():
            convo.append({"role": "user", "text": _join_user_message(question_input, context_input)})
            st.rerun()
        else:
            st.warning("Please enter both a question and a context.")