bhoomi19 commited on
Commit
732cb93
·
verified ·
1 Parent(s): 40f0cc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -17,6 +17,14 @@ from docx import Document
17
  from gtts import gTTS
18
  from io import BytesIO
19
  import spacy
 
 
 
 
 
 
 
 
20
 
21
  # -----------------------------
22
  # Page config
@@ -59,10 +67,9 @@ def load_models():
59
 
60
  return tokenizer_simplify, simplify_model, gen_tokenizer, gen_model, nlp, classifier, summarizer
61
 
62
- tokenizer_simplify, simplify_model, gen_tokenizer, gen_model, nlp, classifier, summarizer = load_models()
63
 
64
- # Hugging Face Spaces likely CPU-only
65
- DEVICE = "cpu"
66
  gen_model.to(DEVICE)
67
 
68
  # -----------------------------
@@ -154,7 +161,8 @@ def fairness_score_visual(text, lang):
154
 
155
  def chat_response(prompt, lang, history):
156
  """Persistent memory chat"""
157
- context = "\n".join([f"User: {u}\nAI: {a}" for u, a in history[-3:]]) # Last 3 turns
 
158
  full_prompt = f"You are a helpful multilingual legal assistant. {context}\nUser: {prompt}\nAI:"
159
  inputs = gen_tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
160
  outputs = gen_model.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9, do_sample=True)
@@ -163,6 +171,7 @@ def chat_response(prompt, lang, history):
163
  response = response.split("AI:")[-1].strip()
164
  return translate_text(response, lang)
165
 
 
166
  # -----------------------------
167
  # Main Streamlit App
168
  # -----------------------------
@@ -214,6 +223,7 @@ def main():
214
  lang = st.selectbox("Chat Language:", LANG_NAMES, index=0)
215
  query = st.text_area("Ask your question:", height=150)
216
 
 
217
  if "chat_history" not in st.session_state:
218
  st.session_state.chat_history = []
219
 
@@ -227,6 +237,7 @@ def main():
227
  if audio:
228
  st.audio(audio, format="audio/mp3")
229
 
 
230
  if st.session_state.chat_history:
231
  st.markdown("### 🧠 Chat History")
232
  for q, a in st.session_state.chat_history[-5:]:
@@ -249,5 +260,6 @@ def main():
249
  *Disclaimer:* Educational use only — not legal advice.
250
  """)
251
 
252
- if __name__ == "__main__":
253
- main()
 
 
17
  from gtts import gTTS
18
  from io import BytesIO
19
  import spacy
20
+ import subprocess
21
+
22
+ # -----------------------------
23
+ # Hugging Face fix: ensure Streamlit runs properly
24
+ # -----------------------------
25
+ #if _name_ == "_main_" and os.environ.get("SYSTEM") == "spaces":
26
+ # subprocess.Popen(["streamlit", "run", "app.py", "--server.port", "7860", "--server.address", "0.0.0.0"])
27
+ # exit()
28
 
29
  # -----------------------------
30
  # Page config
 
67
 
68
  return tokenizer_simplify, simplify_model, gen_tokenizer, gen_model, nlp, classifier, summarizer
69
 
 
70
 
71
+ tokenizer_simplify, simplify_model, gen_tokenizer, gen_model, nlp, classifier, summarizer = load_models()
72
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
73
  gen_model.to(DEVICE)
74
 
75
  # -----------------------------
 
161
 
162
  def chat_response(prompt, lang, history):
163
  """Persistent memory chat"""
164
+ # Combine chat history context
165
+ context = "\n".join([f"User: {u}\nAI: {a}" for u, a in history[-3:]]) # Keep last 3
166
  full_prompt = f"You are a helpful multilingual legal assistant. {context}\nUser: {prompt}\nAI:"
167
  inputs = gen_tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
168
  outputs = gen_model.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9, do_sample=True)
 
171
  response = response.split("AI:")[-1].strip()
172
  return translate_text(response, lang)
173
 
174
+
175
  # -----------------------------
176
  # Main Streamlit App
177
  # -----------------------------
 
223
  lang = st.selectbox("Chat Language:", LANG_NAMES, index=0)
224
  query = st.text_area("Ask your question:", height=150)
225
 
226
+ # Maintain persistent conversation
227
  if "chat_history" not in st.session_state:
228
  st.session_state.chat_history = []
229
 
 
237
  if audio:
238
  st.audio(audio, format="audio/mp3")
239
 
240
+ # Display conversation history
241
  if st.session_state.chat_history:
242
  st.markdown("### 🧠 Chat History")
243
  for q, a in st.session_state.chat_history[-5:]:
 
260
  *Disclaimer:* Educational use only — not legal advice.
261
  """)
262
 
263
+
264
+ if _name_ == "_main_":
265
+ main()