pradeep4321 committed on
Commit
e8141e8
·
verified ·
1 Parent(s): e36d297

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +50 -158
src/streamlit_app.py CHANGED
@@ -1,194 +1,86 @@
1
  import streamlit as st
2
- from huggingface_hub import InferenceClient
 
3
  from langdetect import detect
4
- import numpy as np
5
- import faiss
6
- import tempfile
7
- import speech_recognition as sr
8
- from sentence_transformers import SentenceTransformer
9
- import os
10
 
11
  # ==============================
12
- # CONFIG
13
  # ==============================
14
# Read the Hugging Face API token from the environment. The hosted model
# cannot be called without it, so rendering stops early when it is missing.
HF_TOKEN = os.environ.get("HF_TOKEN")

if not HF_TOKEN:
    st.error("❌ HF_TOKEN not found. Add it in Hugging Face Secrets.")
    st.stop()

client = InferenceClient(
    model="google/gemma-7b-it",
    token=HF_TOKEN,
)


# Streamlit re-executes this script on every interaction; caching the model
# avoids reloading it on each rerun. The spinner is disabled so nothing is
# rendered before st.set_page_config() runs later in the script.
@st.cache_resource(show_spinner=False)
def _load_embed_model():
    """Load the sentence-embedding model used by embed()."""
    return SentenceTransformer("all-MiniLM-L6-v2")


embed_model = _load_embed_model()

# ==============================
# FAISS MEMORY
# ==============================
# Vector size the index is built with; it must equal the length of the
# vectors that embed() produces.
dimension = 384

# Plain module-level objects would be recreated (empty) on every rerun, so
# the index and its parallel text list are kept in the per-session state.
if "memory_index" not in st.session_state:
    st.session_state["memory_index"] = faiss.IndexFlatL2(dimension)
    st.session_state["memory_texts"] = []

# Module-level names preserved for store_memory() / retrieve_memory().
index = st.session_state["memory_index"]
memory_texts = st.session_state["memory_texts"]
33
-
34
def embed(text):
    """Encode ``text`` with the module-level ``embed_model`` and cast to float32."""
    vector = embed_model.encode(text)
    return vector.astype("float32")
36
-
37
def store_memory(src, tgt):
    """Add a ``"src -> tgt"`` entry to the FAISS index and the text list.

    The vector is appended to ``index`` and the text to ``memory_texts`` in
    the same order, so a row number returned by the index can be used to
    look up the text.
    """
    entry = f"{src} -> {tgt}"
    index.add(np.array([embed(entry)]))
    memory_texts.append(entry)
42
-
43
def retrieve_memory(query):
    """Return the stored entry nearest to ``query``, or None if none is stored."""
    if not memory_texts:
        return None

    query_vec = embed(query)
    # search() returns (distances, row indices); only the top hit is requested.
    _distances, neighbours = index.search(np.array([query_vec]), k=1)
    return memory_texts[neighbours[0][0]]
49
 
50
  # ==============================
51
- # SAFE LANGUAGE DETECTION
52
  # ==============================
53
def safe_detect(text):
    """Detect the language of ``text``, falling back to ``"auto"``.

    Args:
        text: Input string; surrounding whitespace is ignored.

    Returns:
        The code returned by ``detect``, or ``"auto"`` when the stripped
        text is shorter than 5 characters or detection raises.
    """
    text = text.strip()

    # Avoid wrong detection for short text
    if len(text) < 5:
        return "auto"

    try:
        return detect(text)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return "auto"
64
-
65
- # ==============================
66
- # SPEECH TO TEXT
67
- # ==============================
68
def speech_to_text(audio_file):
    """Transcribe an uploaded audio file with Google speech recognition.

    Args:
        audio_file: Object whose ``read()`` returns the audio bytes.

    Returns:
        The recognized text, or ``""`` when recognition fails.
    """
    import os  # local import: needed only for temp-file cleanup

    recognizer = sr.Recognizer()

    # The upload is written to a named temp file so it can be opened by path.
    with tempfile.NamedTemporaryFile(delete=False) as temp_audio:
        temp_audio.write(audio_file.read())
        temp_audio_path = temp_audio.name

    try:
        with sr.AudioFile(temp_audio_path) as source:
            audio = recognizer.record(source)

        try:
            text = recognizer.recognize_google(audio)
        except Exception:
            # Best effort: any recognition failure yields an empty result.
            text = ""
    finally:
        # delete=False leaves the file on disk; remove it explicitly so
        # repeated uploads do not accumulate temp files.
        os.remove(temp_audio_path)

    return text
84
 
85
  # ==============================
86
  # TRANSLATION FUNCTION
87
  # ==============================
88
  def translate(text, target_lang):
89
 
90
- src_lang = safe_detect(text)
91
- memory = retrieve_memory(text)
92
-
93
- # 🚨 Handle very short input
94
- if len(text.split()) <= 1:
95
- return "⚠️ Please enter a full sentence for better translation.", src_lang, memory
96
-
97
- # Prompt design
98
- if src_lang == "auto":
99
- prompt = f"""
100
- You are a professional multilingual translator.
101
-
102
- Detect the language and translate into {target_lang}.
103
-
104
- Text:
105
- {text}
106
-
107
- Rules:
108
- - Only return translated text
109
- - No explanation
110
- """
111
- else:
112
- prompt = f"""
113
- You are a professional multilingual translator.
114
-
115
- Translate from {src_lang} to {target_lang}.
116
-
117
- Text:
118
- {text}
119
-
120
- Rules:
121
- - Only return translated text
122
- - No explanation
123
- """
124
-
125
  try:
126
- response = client.text_generation(
127
- prompt,
128
- max_new_tokens=150,
129
- temperature=0.2,
130
- top_p=0.9
131
- )
132
 
133
- translated = response.strip()
134
 
135
- # Handle empty or bad output
136
- if not translated or len(translated) < 2:
137
- translated = "❌ Unable to translate. Try a clearer sentence."
138
 
139
- except Exception as e:
140
- translated = f"❌ Translation failed: {str(e)}"
 
 
 
141
 
142
- store_memory(text, translated)
143
 
144
- return translated, src_lang, memory
145
 
146
  # ==============================
147
  # UI
148
  # ==============================
149
st.set_page_config(page_title="AI Translator", layout="wide")

st.title("🌍 AI Translator with Voice (Gemma 7B)")

tab1, tab2 = st.tabs(["📝 Text Input", "🎤 Voice Input"])

# TEXT INPUT
with tab1:
    typed_text = st.text_area("Enter text", height=150)

# VOICE INPUT
with tab2:
    audio_file = st.file_uploader("Upload audio (wav/mp3)", type=["wav", "mp3"])

    if audio_file:
        st.audio(audio_file)

        if st.button("Convert Speech to Text"):
            with st.spinner("Processing audio..."):
                # st.button() is True only during the rerun caused by the
                # click, so the transcript is kept in session state to
                # survive the later rerun triggered by "Translate".
                st.session_state["voice_text"] = speech_to_text(audio_file)

        if "voice_text" in st.session_state:
            st.success("Recognized Text:")
            st.write(st.session_state["voice_text"])

# Typed text wins; the stored transcript is used when the text box is empty.
if typed_text.strip():
    input_text = typed_text
else:
    input_text = st.session_state.get("voice_text", "")

# TARGET LANGUAGE
target_lang = st.selectbox(
    "Target Language",
    ["English", "Tamil", "Hindi", "French", "Arabic", "Spanish", "German"],
)

# TRANSLATE
if st.button("Translate"):
    if not input_text.strip():
        st.warning("Please provide input text or audio")
    else:
        with st.spinner("Translating..."):
            output, src_lang, memory = translate(input_text, target_lang)

        st.success("✅ Translation")
        st.write(output)

        st.info(f"Detected Language: {src_lang}")

        if memory:
            st.caption(f"💡 Similar past translation: {memory}")
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from langdetect import detect
 
 
 
 
 
 
5
 
6
  # ==============================
7
+ # LOAD MODEL (ONLY ONCE)
8
  # ==============================
9
@st.cache_resource
def load_model():
    """Load the NLLB tokenizer and seq2seq model from one checkpoint.

    Wrapped in ``st.cache_resource`` so that the load happens only once
    rather than on every script rerun.

    Returns:
        ``(tokenizer, model)``.
    """
    checkpoint = "facebook/nllb-200-distilled-600M"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


tokenizer, model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # ==============================
18
+ # LANGUAGE MAP
19
  # ==============================
20
# Display name shown in the UI -> language code passed to the model as the
# translation target.
lang_map = {
    "English": "eng_Latn",
    "Tamil": "tam_Taml",
    "Hindi": "hin_Deva",
    "French": "fra_Latn",
    "Arabic": "arb_Arab",
    "Spanish": "spa_Latn",
    "German": "deu_Latn",
}

# Code returned by language detection -> the same model language codes.
detect_map = {
    "en": "eng_Latn",
    "ta": "tam_Taml",
    "hi": "hin_Deva",
    "fr": "fra_Latn",
    "ar": "arb_Arab",
    "es": "spa_Latn",
    "de": "deu_Latn",
}
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # ==============================
41
  # TRANSLATION FUNCTION
42
  # ==============================
43
  def translate(text, target_lang):
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  try:
46
+ detected = detect(text)
47
+ src_lang = detect_map.get(detected, "eng_Latn")
48
+ except:
49
+ src_lang = "eng_Latn"
50
+
51
+ tgt_lang = lang_map[target_lang]
52
 
53
+ tokenizer.src_lang = src_lang
54
 
55
+ encoded = tokenizer(text, return_tensors="pt")
 
 
56
 
57
+ generated_tokens = model.generate(
58
+ **encoded,
59
+ forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
60
+ max_length=200
61
+ )
62
 
63
+ translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
64
 
65
+ return translated, src_lang
66
 
67
  # ==============================
68
  # UI
69
  # ==============================
70
# Page layout: title, source text box, target-language picker, action button.
st.title("🌍 NLLB Translator (Transformers)")

input_text = st.text_area("Enter text")

target_lang = st.selectbox("Target Language", list(lang_map))

if st.button("Translate"):
    if input_text.strip():
        with st.spinner("Translating..."):
            output, src_lang = translate(input_text, target_lang)

        st.success("✅ Translation")
        st.write(output)

        st.info(f"Detected Language: {src_lang}")
    else:
        # Nothing but whitespace was entered.
        st.warning("Enter text")