Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,6 +17,14 @@ from docx import Document
|
|
| 17 |
from gtts import gTTS
|
| 18 |
from io import BytesIO
|
| 19 |
import spacy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# -----------------------------
|
| 22 |
# Page config
|
|
@@ -59,10 +67,9 @@ def load_models():
|
|
| 59 |
|
| 60 |
return tokenizer_simplify, simplify_model, gen_tokenizer, gen_model, nlp, classifier, summarizer
|
| 61 |
|
| 62 |
-
tokenizer_simplify, simplify_model, gen_tokenizer, gen_model, nlp, classifier, summarizer = load_models()
|
| 63 |
|
| 64 |
-
|
| 65 |
-
DEVICE = "cpu"
|
| 66 |
gen_model.to(DEVICE)
|
| 67 |
|
| 68 |
# -----------------------------
|
|
@@ -154,7 +161,8 @@ def fairness_score_visual(text, lang):
|
|
| 154 |
|
| 155 |
def chat_response(prompt, lang, history):
|
| 156 |
"""Persistent memory chat"""
|
| 157 |
-
|
|
|
|
| 158 |
full_prompt = f"You are a helpful multilingual legal assistant. {context}\nUser: {prompt}\nAI:"
|
| 159 |
inputs = gen_tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
|
| 160 |
outputs = gen_model.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9, do_sample=True)
|
|
@@ -163,6 +171,7 @@ def chat_response(prompt, lang, history):
|
|
| 163 |
response = response.split("AI:")[-1].strip()
|
| 164 |
return translate_text(response, lang)
|
| 165 |
|
|
|
|
| 166 |
# -----------------------------
|
| 167 |
# Main Streamlit App
|
| 168 |
# -----------------------------
|
|
@@ -214,6 +223,7 @@ def main():
|
|
| 214 |
lang = st.selectbox("Chat Language:", LANG_NAMES, index=0)
|
| 215 |
query = st.text_area("Ask your question:", height=150)
|
| 216 |
|
|
|
|
| 217 |
if "chat_history" not in st.session_state:
|
| 218 |
st.session_state.chat_history = []
|
| 219 |
|
|
@@ -227,6 +237,7 @@ def main():
|
|
| 227 |
if audio:
|
| 228 |
st.audio(audio, format="audio/mp3")
|
| 229 |
|
|
|
|
| 230 |
if st.session_state.chat_history:
|
| 231 |
st.markdown("### 🧠 Chat History")
|
| 232 |
for q, a in st.session_state.chat_history[-5:]:
|
|
@@ -249,5 +260,6 @@ def main():
|
|
| 249 |
*Disclaimer:* Educational use only — not legal advice.
|
| 250 |
""")
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
|
|
|
|
|
| 17 |
from gtts import gTTS
|
| 18 |
from io import BytesIO
|
| 19 |
import spacy
|
| 20 |
+
import subprocess
|
| 21 |
+
|
| 22 |
+
# -----------------------------
|
| 23 |
+
# Hugging Face fix: ensure Streamlit runs properly
|
| 24 |
+
# -----------------------------
|
| 25 |
+
#if __name__ == "__main__" and os.environ.get("SYSTEM") == "spaces":
|
| 26 |
+
# subprocess.Popen(["streamlit", "run", "app.py", "--server.port", "7860", "--server.address", "0.0.0.0"])
|
| 27 |
+
# exit()
|
| 28 |
|
| 29 |
# -----------------------------
|
| 30 |
# Page config
|
|
|
|
| 67 |
|
| 68 |
return tokenizer_simplify, simplify_model, gen_tokenizer, gen_model, nlp, classifier, summarizer
|
| 69 |
|
|
|
|
| 70 |
|
| 71 |
+
tokenizer_simplify, simplify_model, gen_tokenizer, gen_model, nlp, classifier, summarizer = load_models()
|
| 72 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 73 |
gen_model.to(DEVICE)
|
| 74 |
|
| 75 |
# -----------------------------
|
|
|
|
| 161 |
|
| 162 |
def chat_response(prompt, lang, history):
|
| 163 |
"""Persistent memory chat"""
|
| 164 |
+
# Combine chat history context
|
| 165 |
+
context = "\n".join([f"User: {u}\nAI: {a}" for u, a in history[-3:]]) # Keep last 3
|
| 166 |
full_prompt = f"You are a helpful multilingual legal assistant. {context}\nUser: {prompt}\nAI:"
|
| 167 |
inputs = gen_tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
|
| 168 |
outputs = gen_model.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9, do_sample=True)
|
|
|
|
| 171 |
response = response.split("AI:")[-1].strip()
|
| 172 |
return translate_text(response, lang)
|
| 173 |
|
| 174 |
+
|
| 175 |
# -----------------------------
|
| 176 |
# Main Streamlit App
|
| 177 |
# -----------------------------
|
|
|
|
| 223 |
lang = st.selectbox("Chat Language:", LANG_NAMES, index=0)
|
| 224 |
query = st.text_area("Ask your question:", height=150)
|
| 225 |
|
| 226 |
+
# Maintain persistent conversation
|
| 227 |
if "chat_history" not in st.session_state:
|
| 228 |
st.session_state.chat_history = []
|
| 229 |
|
|
|
|
| 237 |
if audio:
|
| 238 |
st.audio(audio, format="audio/mp3")
|
| 239 |
|
| 240 |
+
# Display conversation history
|
| 241 |
if st.session_state.chat_history:
|
| 242 |
st.markdown("### 🧠 Chat History")
|
| 243 |
for q, a in st.session_state.chat_history[-5:]:
|
|
|
|
| 260 |
*Disclaimer:* Educational use only — not legal advice.
|
| 261 |
""")
|
| 262 |
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
main()
|