gaja1995 commited on
Commit
e28b3ce
·
verified ·
1 Parent(s): 9564121

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -61
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
- import torch
4
  from langdetect import detect
5
  import time
6
  import warnings
 
7
 
8
  # Suppress warnings
9
  warnings.filterwarnings("ignore")
@@ -16,36 +16,28 @@ st.set_page_config(
16
  initial_sidebar_state="expanded"
17
  )
18
 
19
- # Load models (with caching to avoid reloading)
 
 
 
 
20
  @st.cache_resource
21
- def load_models():
22
- # Translation models
23
- en_to_hi = pipeline("translation", model="Helsinki-NLP/opus-mt-en-hi")
24
- hi_to_en = pipeline("translation", model="Helsinki-NLP/opus-mt-hi-en")
25
-
26
- en_to_ta = pipeline("translation", model="Helsinki-NLP/opus-mt-en-ta")
27
- ta_to_en = pipeline("translation", model="Helsinki-NLP/opus-mt-ta-en")
28
-
29
- # For other languages, we'll use a multilingual model
30
- multilingual_translator = pipeline("translation", model="facebook/mbart-large-50-many-to-many-mmt")
31
-
32
- # Load GUVI-specific model (fine-tuned GPT)
33
- guvi_tokenizer = AutoTokenizer.from_pretrained("gpt2")
34
- guvi_model = AutoModelForSeq2SeqLM.from_pretrained("gpt2")
35
-
36
- return {
37
- "en_to_hi": en_to_hi,
38
- "hi_to_en": hi_to_en,
39
- "en_to_ta": en_to_ta,
40
- "ta_to_en": ta_to_en,
41
- "multilingual": multilingual_translator,
42
- "guvi_tokenizer": guvi_tokenizer,
43
- "guvi_model": guvi_model
44
- }
45
 
46
- # Initialize models
47
- with st.spinner("Loading models... This may take a few minutes."):
48
- models = load_models()
49
 
50
  # Language mapping
51
  language_map = {
@@ -66,32 +58,30 @@ def detect_language(text):
66
  except:
67
  return "en"
68
 
69
- # Function to translate text
70
- def translate_text(text, source_lang, target_lang):
71
  if source_lang == target_lang:
72
  return text
73
 
74
- # Handle specific language pairs with dedicated models
75
- if source_lang == "en" and target_lang == "hi":
76
- return models["en_to_hi"](text)[0]['translation_text']
77
- elif source_lang == "hi" and target_lang == "en":
78
- return models["hi_to_en"](text)[0]['translation_text']
79
- elif source_lang == "en" and target_lang == "ta":
80
- return models["en_to_ta"](text)[0]['translation_text']
81
- elif source_lang == "ta" and target_lang == "en":
82
- return models["ta_to_en"](text)[0]['translation_text']
83
- else:
84
- # Use multilingual model for other languages
85
- return models["multilingual"](text, src_lang=source_lang, tgt_lang=target_lang)[0]['translation_text']
86
 
87
- # Function to generate GUVI-specific response
88
- def generate_guvi_response(prompt):
89
- # Tokenize the input
90
- inputs = models["guvi_tokenizer"](prompt, return_tensors="pt", max_length=512, truncation=True)
 
 
 
 
 
91
 
92
- # Generate response
93
  with torch.no_grad():
94
- outputs = models["guvi_model"].generate(
95
  **inputs,
96
  max_length=200,
97
  num_beams=5,
@@ -99,10 +89,8 @@ def generate_guvi_response(prompt):
99
  temperature=0.7
100
  )
101
 
102
- # Decode the output
103
- response = models["guvi_tokenizer"].decode(outputs[0], skip_special_tokens=True)
104
-
105
- return response
106
 
107
  # Streamlit UI
108
  def main():
@@ -164,8 +152,8 @@ def main():
164
  st.sidebar.markdown("### About")
165
  st.sidebar.markdown("""
166
  This chatbot is powered by:
167
- - Hugging Face Transformers
168
- - Streamlit
169
  - GUVI's custom knowledge base
170
 
171
  Developed for GUVI's multilingual learners.
@@ -197,18 +185,18 @@ def main():
197
  with st.spinner("Thinking..."):
198
  # Translate to English if needed
199
  if input_lang != "en":
200
- translated_prompt = translate_text(prompt, input_lang, "en")
201
  else:
202
  translated_prompt = prompt
203
 
204
- # Generate response from GUVI model
205
- guvi_response = generate_guvi_response(translated_prompt)
206
 
207
  # Translate back to user's language if needed
208
  if target_lang != "en":
209
- final_response = translate_text(guvi_response, "en", target_lang)
210
  else:
211
- final_response = guvi_response
212
 
213
  # Add a small delay for natural conversation flow
214
  time.sleep(0.5)
 
1
  import streamlit as st
2
+ from googletrans import Translator
 
3
  from langdetect import detect
4
  import time
5
  import warnings
6
+ import os
7
 
8
  # Suppress warnings
9
  warnings.filterwarnings("ignore")
 
16
  initial_sidebar_state="expanded"
17
  )
18
 
19
+
20
+ # Initialize Google Translator
21
+ translator = Translator()
22
+
23
+ # Load GUVI dataset
24
  @st.cache_resource
25
+ def load_guvi_dataset():
26
+ qa_pairs = {}
27
+ try:
28
+ with open("GUVI dataset.txt", "r", encoding="utf-8") as file:
29
+ lines = file.readlines()
30
+ for i in range(0, len(lines), 2):
31
+ if i+1 < len(lines):
32
+ question = lines[i].strip()
33
+ answer = lines[i+1].strip()
34
+ qa_pairs[question.lower()] = answer
35
+ except FileNotFoundError:
36
+ st.error("GUVI dataset (guvi.txt) not found. Using GPT-only responses.")
37
+ return qa_pairs
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ # Initialize dataset
40
+ qa_pairs = load_guvi_dataset()
 
41
 
42
  # Language mapping
43
  language_map = {
 
58
  except:
59
  return "en"
60
 
61
+ # Function to translate text using Google Translator
62
+ def translate_text(text, target_lang, source_lang='auto'):
63
  if source_lang == target_lang:
64
  return text
65
 
66
+ try:
67
+ translation = translator.translate(text, src=source_lang, dest=target_lang)
68
+ return translation.text
69
+ except Exception as e:
70
+ st.warning(f"Translation error: {e}. Returning original text.")
71
+ return text
 
 
 
 
 
 
72
 
73
+ # Function to generate response using GPT or GUVI dataset
74
+ def generate_response(prompt):
75
+ # First check if the question exists in our GUVI dataset
76
+ lower_prompt = prompt.lower()
77
+ if lower_prompt in qa_pairs:
78
+ return qa_pairs[lower_prompt]
79
+
80
+ # If not found in dataset, use Hugging Face model
81
+ inputs = models["chat_tokenizer"](prompt, return_tensors="pt", max_length=512, truncation=True)
82
 
 
83
  with torch.no_grad():
84
+ outputs = models["chat_model"].generate(
85
  **inputs,
86
  max_length=200,
87
  num_beams=5,
 
89
  temperature=0.7
90
  )
91
 
92
+ return models["chat_tokenizer"].decode(outputs[0], skip_special_tokens=True)
93
+
 
 
94
 
95
  # Streamlit UI
96
  def main():
 
152
  st.sidebar.markdown("### About")
153
  st.sidebar.markdown("""
154
  This chatbot is powered by:
155
+ - OpenAI GPT
156
+ - Google Translator
157
  - GUVI's custom knowledge base
158
 
159
  Developed for GUVI's multilingual learners.
 
185
  with st.spinner("Thinking..."):
186
  # Translate to English if needed
187
  if input_lang != "en":
188
+ translated_prompt = translate_text(prompt, "en", input_lang)
189
  else:
190
  translated_prompt = prompt
191
 
192
+ # Generate response
193
+ response = generate_response(translated_prompt)
194
 
195
  # Translate back to user's language if needed
196
  if target_lang != "en":
197
+ final_response = translate_text(response, target_lang, "en")
198
  else:
199
+ final_response = response
200
 
201
  # Add a small delay for natural conversation flow
202
  time.sleep(0.5)