pradeepsengarr committed on
Commit
03245d6
·
verified ·
1 Parent(s): 2c3bba0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -207
app.py CHANGED
@@ -1,212 +1,60 @@
1
  import gradio as gr
2
- from sentence_transformers import SentenceTransformer
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM, BitsAndBytesConfig
4
  import faiss
5
- import numpy as np
6
- from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- import fitz
8
- import os
9
  import torch
10
-
11
- # --- Global Variables ---
12
- index = None
13
- doc_texts = []
14
- hf_token = os.environ.get("HF_TOKEN") # Get the Hugging Face token
15
-
16
- # Language Codes for given languages
17
- lang_map = {
18
- "English": "eng_Latn",
19
- "Hindi": "hin_Deva",
20
- "Marathi": "mar_Deva",
21
- "Punjabi": "pan_Guru"
22
- }
23
-
24
- # --- Model Loading (will be loaded once on Space startup) ---
25
-
26
- # For Embedding - Using a smaller, more CPU-friendly SentenceTransformer model
27
- # This model is generally small enough that quantization isn't critically needed for it.
28
- embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", token=hf_token)
29
-
30
- # For LLM - Using "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
31
- llm_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
32
- tokenizer = AutoTokenizer.from_pretrained(llm_model_id, token=hf_token)
33
-
34
- # --- Quantization Configuration ---
35
- # Choose one of the following quantization methods based on your needs and resources:
36
-
37
- # Option 1: 8-bit quantization (generally good balance of performance and memory)
38
- # Requires `bitsandbytes` library: pip install bitsandbytes accelerate
39
- quantization_config = BitsAndBytesConfig(
40
- load_in_8bit=True,
41
- bnb_8bit_compute_dtype=torch.float16 # Use float16 for compute if possible
42
- )
43
-
44
- # Option 2: 4-bit quantization (most aggressive memory reduction, potential small accuracy hit)
45
- # Requires `bitsandbytes` library: pip install bitsandbytes accelerate
46
- # quantization_config = BitsAndBytesConfig(
47
- # load_in_4bit=True,
48
- # bnb_4bit_quant_type="nf4", # NormalFloat 4-bit
49
- # bnb_4bit_compute_dtype=torch.float16, # Use float16 for compute if possible
50
- # bnb_4bit_use_double_quant=True, # Double quantization for slightly better precision
51
- # )
52
-
53
- # Load the LLM with quantization
54
- model = AutoModelForCausalLM.from_pretrained(
55
- llm_model_id,
56
- quantization_config=quantization_config, # Apply the quantization config
57
- device_map="auto", # Automatically places model parts, often on CPU for 8bit/4bit on CPU-only Spaces
58
- token=hf_token
59
- )
60
-
61
- llm = pipeline(
62
- "text-generation",
63
- model=model,
64
- tokenizer=tokenizer,
65
- max_new_tokens=300,
66
- do_sample=True,
67
- temperature=0.7,
68
- )
69
-
70
- # Load a smaller FB Translation Model
71
- # NLLB-200M is still relatively big. Quantizing it can be tricky for Seq2Seq models
72
- # with `bitsandbytes` directly for generation quality. If OOM issues persist,
73
- # consider a much smaller NLLB variant, or a different approach for translation.
74
- nllb_id = "facebook/nllb-200-distilled-600M" # This model is 600M params, can still be large
75
- nllb_tokenizer = AutoTokenizer.from_pretrained(nllb_id)
76
- nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
77
- nllb_id,
78
- # For NLLB, direct bitsandbytes quantization might need more testing for quality.
79
- # If you encounter OOM, uncomment below lines for 8-bit if compatible and test:
80
- # quantization_config=BitsAndBytesConfig(load_in_8bit=True),
81
- device_map="auto",
82
- token=hf_token
83
  )
84
- translator = pipeline("translation", model=nllb_model, tokenizer=nllb_tokenizer)
85
-
86
- # --- Functions ---
87
-
88
- # Extract PDF text
89
- def extract_text_from_pdf(file_path):
90
- text = ""
91
- doc = fitz.open(file_path)
92
- for page in doc:
93
- text += page.get_text()
94
- return text
95
-
96
- # Upload data file handler
97
- def process_file(file):
98
- global index, doc_texts
99
-
100
- if file is None:
101
- return "Please upload a file to process.", gr.Dropdown.update(choices=["English", "Hindi", "Marathi", "Punjabi"], value="English", interactive=False), gr.Textbox.update(interactive=False), gr.Button.update(interactive=False)
102
-
103
- filename = file.name
104
- if filename.endswith(".pdf"):
105
- text = extract_text_from_pdf(file.name)
106
- elif filename.endswith(".txt"):
107
- with open(file.name, "r", encoding="utf-8") as f:
108
- text = f.read()
109
- else:
110
- return "Upload the correct files (PDF or TXT).", gr.Dropdown.update(choices=["English", "Hindi", "Marathi", "Punjabi"], value="English", interactive=False), gr.Textbox.update(interactive=False), gr.Button.update(interactive=False)
111
-
112
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
113
- doc_texts = text_splitter.split_text(text)
114
-
115
- # Ensure embeddings are float32 for FAISS if `bnb_8bit_compute_dtype` changes it
116
- embeddings = embed_model.encode(doc_texts).astype(np.float32)
117
- dim = embeddings.shape[1]
118
- index = faiss.IndexFlatL2(dim)
119
- index.add(np.array(embeddings))
120
-
121
- return "Files uploaded and processed successfully!", gr.Dropdown.update(interactive=True), gr.Textbox.update(interactive=True), gr.Button.update(interactive=True)
122
-
123
- # Retrieve context using FAISS
124
- def get_context(question, k=3):
125
- question_embedding = embed_model.encode([question]).astype(np.float32)
126
- _, I = index.search(np.array(question_embedding), k)
127
- return "\n".join([doc_texts[i] for i in I[0]])
128
-
129
- # Answers with the Translation Option
130
- def generate_answer(question, lang_choice):
131
- if index is None:
132
- return "Please upload and process a file first."
133
-
134
- context = get_context(question)
135
- # Using chat template to ensure proper formatting for TinyLlama
136
- messages = [
137
- {"role": "system", "content": "You are a helpful assistant. Answer strictly based on the context."},
138
- {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"}
139
- ]
140
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
141
-
142
-
143
- try:
144
- result = llm(prompt)
145
- # Extract the answer from the generated text, often it's after the last "Assistant" turn
146
- # This can be tricky with conversational models, you might need to adjust extraction logic
147
- # based on exact model output.
148
- generated_text = result[0]['generated_text']
149
- # A common way to get the response after the final user turn:
150
- answer = generated_text.split("assistant\n")[-1].strip()
151
- # For TinyLlama-1.1B-Chat-v1.0, it might be safer to parse the entire output or use `max_new_tokens` carefully
152
- # to ensure it doesn't repeat the prompt too much.
153
-
154
- if lang_choice != "English":
155
- src_lang = "eng_Latn"
156
- tgt_lang = lang_map.get(lang_choice, "eng_Latn")
157
- translated = translator(answer, src_lang=src_lang, tgt_lang=tgt_lang)
158
- return translated[0]['translation_text']
159
- else:
160
- return answer
161
-
162
- except Exception as e:
163
- return f"Error generating answer: {str(e)}"
164
-
165
- # --- Gradio UI ---
166
- with gr.Blocks(title="Multilingual RAG Chatbot with Quantization") as demo:
167
- gr.Markdown(
168
- """
169
- # Multilingual RAG Chatbot
170
- Upload your PDF or TXT file, then ask questions. The chatbot will retrieve relevant information
171
- and generate an answer, which can then be translated into your chosen language.
172
- """
173
- )
174
-
175
- with gr.Row():
176
- with gr.Column():
177
- file_input = gr.File(label="1. Upload Document (PDF or TXT)", file_types=[".txt", ".pdf"])
178
- upload_status = gr.Textbox(label="Processing Status", interactive=False, placeholder="No file uploaded yet.")
179
- upload_button = gr.Button("Process Document") # Explicit button to trigger processing
180
- with gr.Column():
181
- gr.Markdown("---") # Visual separator
182
- gr.Markdown("### 2. Ask a Question and Get Answer")
183
- question_box = gr.Textbox(label="Your Question", placeholder="e.g., What is the main topic of the document?")
184
- lang_dropdown = gr.Dropdown(
185
- label="Output Language",
186
- choices=["English", "Hindi", "Marathi", "Punjabi"],
187
- value="English",
188
- interactive=False # Initially disable until file is processed
189
- )
190
- generate_button = gr.Button("Generate Answer") # Explicit button for generation
191
- answer_box = gr.Textbox(label="Answer", interactive=False, lines=5, placeholder="The answer will appear here...")
192
-
193
- # Event handling
194
- upload_button.click(
195
- fn=process_file,
196
- inputs=file_input,
197
- outputs=[upload_status, lang_dropdown, question_box, generate_button] # Enable other components on success
198
- )
199
- generate_button.click(
200
- fn=generate_answer,
201
- inputs=[question_box, lang_dropdown],
202
- outputs=answer_box
203
- )
204
- # Also allow 'Enter' key for question box
205
- question_box.submit(
206
- fn=generate_answer,
207
- inputs=[question_box, lang_dropdown],
208
- outputs=answer_box
209
- )
210
-
211
 
212
- demo.launch()
 
1
  import gradio as gr
 
 
2
  import faiss
 
 
 
 
3
  import torch
4
+ from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
6
+
7
+ # ---------- Load models ----------
8
+ embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
9
+
10
+ gen_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
11
+ gen_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype=torch.float32)
12
+
13
+ # Example: EN->HI
14
+ trans_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
15
+ trans_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
16
+
17
+ # ---------- Sample docs + FAISS index ----------
18
+ documents = [
19
+ "The Taj Mahal is an ivory-white marble mausoleum in India.",
20
+ "ChatGPT is a large language model developed by OpenAI.",
21
+ "RAG combines retrieval-based and generation-based approaches."
22
+ ]
23
+
24
+ doc_embeddings = embed_model.encode(documents, convert_to_tensor=True)
25
+ index = faiss.IndexFlatL2(doc_embeddings.shape[1])
26
+ index.add(doc_embeddings.cpu().numpy())
27
+
# ---------- RAG Function ----------
def rag_translate(query, target_lang='hi'):
    """Answer *query* from the indexed documents, optionally translating.

    Embeds the query, retrieves the single nearest document from the FAISS
    index, asks Phi-2 to answer from that context, and — when the target
    language is 'hi' — translates the answer with the EN->HI Opus-MT model.

    Args:
        query: The user's question (English text).
        target_lang: Desired output language code. Only 'hi' has a loaded
            translation model; 'en' (or a falsy value) skips translation.

    Returns:
        A formatted string with the English answer and, when supported,
        its translation.
    """
    # Embed the query and fetch the closest document (k=1) from the index.
    query_vec = embed_model.encode([query])
    _, top_indices = index.search(query_vec, 1)
    retrieved_doc = documents[top_indices[0][0]]

    prompt = f"Context: {retrieved_doc}\nQuestion: {query}\nAnswer:"
    inputs = gen_tokenizer(prompt, return_tensors="pt")
    outputs = gen_model.generate(**inputs, max_new_tokens=64)

    # BUG FIX: decode only the newly generated tokens. Decoding outputs[0]
    # wholesale echoed the whole prompt (Context/Question) back into the
    # "answer", since causal-LM generate() returns prompt + completion.
    prompt_len = inputs["input_ids"].shape[1]
    answer_en = gen_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    # Translate if requested
    if target_lang and target_lang != 'en':
        if target_lang != 'hi':
            # BUG FIX: only an EN->HI model is loaded, but the UI offers
            # 'fr'/'es'; running those through the Hindi model silently
            # produced Hindi output. Report the limitation instead.
            return (f"πŸ” Answer:\n{answer_en}\n\n"
                    f"🌐 Translation to '{target_lang}' is not supported (only 'hi' is available).")
        trans_inputs = trans_tokenizer(answer_en, return_tensors="pt", truncation=True)
        trans_output = trans_model.generate(**trans_inputs)
        translated = trans_tokenizer.decode(trans_output[0], skip_special_tokens=True)
        return f"πŸ” Answer:\n{answer_en}\n\n🌐 Translated:\n{translated}"

    return f"πŸ” Answer:\n{answer_en}"
# ---------- Gradio UI ----------
# Build the widgets up front, then wire them into a single Interface.
question_input = gr.Textbox(label="Ask a Question")
language_input = gr.Dropdown(choices=["en", "hi", "fr", "es"], value="hi", label="Target Language")
answer_output = gr.Textbox(label="Answer")

# One-shot Q&A interface: (question, target language) -> formatted answer.
iface = gr.Interface(
    fn=rag_translate,
    inputs=[question_input, language_input],
    outputs=answer_output,
    title="🧠 RAG + 🌍 Translator",
    description="A lightweight RAG system with answer translation. Powered by Phi-2 + MiniLM + Opus MT."
)

iface.launch()