pradeepsengarr committed on
Commit
554fa87
·
verified ·
1 Parent(s): 63609d2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sentence_transformers import SentenceTransformer
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM, BitsAndBytesConfig
4
+ import faiss
5
+ import numpy as np
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ import fitz
8
+ import os
9
+ import torch
10
+
11
# --- Global Variables ---
index = None  # FAISS IndexFlatL2 over chunk embeddings; built lazily in process_file()
doc_texts = []  # document chunks, row-aligned with the FAISS index
hf_token = os.environ.get("HF_TOKEN") # Get the Hugging Face token

# Language Codes for given languages — these are the codes passed to the
# NLLB translator as src_lang/tgt_lang in generate_answer().
lang_map = {
    "English": "eng_Latn",
    "Hindi": "hin_Deva",
    "Marathi": "mar_Deva",
    "Punjabi": "pan_Guru"
}
23
+
24
# --- Model Loading (will be loaded once on Space startup) ---

# For Embedding - Using a smaller, more CPU-friendly SentenceTransformer model
# This model is generally small enough that quantization isn't critically needed for it.
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", token=hf_token)

# For LLM - Using "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
llm_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(llm_model_id, token=hf_token)

# --- Quantization Configuration ---
# Choose one of the following quantization methods based on your needs and resources:

# Option 1: 8-bit quantization (generally good balance of performance and memory)
# Requires `bitsandbytes` library: pip install bitsandbytes accelerate
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.float16  # Use float16 for compute if possible
)

# Option 2: 4-bit quantization (most aggressive memory reduction, potential small accuracy hit)
# Requires `bitsandbytes` library: pip install bitsandbytes accelerate
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",  # NormalFloat 4-bit
#     bnb_4bit_compute_dtype=torch.float16,  # Use float16 for compute if possible
#     bnb_4bit_use_double_quant=True,  # Double quantization for slightly better precision
# )

# Load the LLM with quantization
model = AutoModelForCausalLM.from_pretrained(
    llm_model_id,
    quantization_config=quantization_config,  # Apply the quantization config
    device_map="auto",  # Automatically places model parts, often on CPU for 8bit/4bit on CPU-only Spaces
    token=hf_token
)

# Generation pipeline shared by generate_answer(); sampling is enabled, so
# answers are NOT deterministic across calls.
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=300,
    do_sample=True,
    temperature=0.7,
)

# Load a smaller FB Translation Model
# NLLB-200M is still relatively big. Quantizing it can be tricky for Seq2Seq models
# with `bitsandbytes` directly for generation quality. If OOM issues persist,
# consider a much smaller NLLB variant, or a different approach for translation.
nllb_id = "facebook/nllb-200-distilled-600M"  # This model is 600M params, can still be large
nllb_tokenizer = AutoTokenizer.from_pretrained(nllb_id)
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
    nllb_id,
    # For NLLB, direct bitsandbytes quantization might need more testing for quality.
    # If you encounter OOM, uncomment below lines for 8-bit if compatible and test:
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
    token=hf_token
)
translator = pipeline("translation", model=nllb_model, tokenizer=nllb_tokenizer)
85
+
86
+ # --- Functions ---
87
+
88
# Extract PDF text
def extract_text_from_pdf(file_path):
    """Extract and concatenate the plain text of every page of a PDF.

    Args:
        file_path: Filesystem path to the PDF document.

    Returns:
        A single string containing all page text, in page order.
    """
    text = ""
    # Use the document as a context manager so the underlying file handle is
    # closed even if text extraction raises (the original never closed it).
    with fitz.open(file_path) as doc:
        for page in doc:
            text += page.get_text()
    return text
95
+
96
# Upload data file handler
def process_file(file):
    """Ingest an uploaded PDF/TXT file and (re)build the FAISS index.

    Populates the module-level ``index`` and ``doc_texts`` globals.

    Args:
        file: Gradio file object (has a ``.name`` path attribute), or None.

    Returns:
        A 4-tuple of (status message, dropdown update, textbox update,
        button update) matching the outputs wired up in the UI.
    """
    global index, doc_texts

    # Updates that keep the Q&A widgets disabled when processing fails.
    disabled_ui = (
        gr.Dropdown.update(choices=["English", "Hindi", "Marathi", "Punjabi"],
                           value="English", interactive=False),
        gr.Textbox.update(interactive=False),
        gr.Button.update(interactive=False),
    )

    if file is None:
        return ("Please upload a file to process.", *disabled_ui)

    # Lowercase so uppercase extensions (".PDF", ".TXT") are accepted too —
    # the original case-sensitive check rejected them.
    filename = file.name.lower()
    if filename.endswith(".pdf"):
        text = extract_text_from_pdf(file.name)
    elif filename.endswith(".txt"):
        with open(file.name, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        return ("Upload the correct files (PDF or TXT).", *disabled_ui)

    # Chunk the document so retrieved passages fit the LLM prompt budget.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    doc_texts = text_splitter.split_text(text)

    # Ensure embeddings are float32 — FAISS requires float32 input.
    embeddings = embed_model.encode(doc_texts).astype(np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    return ("Files uploaded and processed successfully!",
            gr.Dropdown.update(interactive=True),
            gr.Textbox.update(interactive=True),
            gr.Button.update(interactive=True))
122
+
123
# Retrieve context using FAISS
def get_context(question, k=3):
    """Return the k document chunks most similar to *question*, newline-joined.

    Args:
        question: Natural-language query to embed and search with.
        k: Maximum number of chunks to retrieve (default 3).

    Returns:
        The retrieved chunks joined with "\n".
    """
    # FAISS pads results with index -1 when k exceeds the number of stored
    # vectors, which would make doc_texts[-1] return the wrong chunk — clamp.
    k = min(k, len(doc_texts))
    question_embedding = embed_model.encode([question]).astype(np.float32)
    _, I = index.search(np.array(question_embedding), k)
    return "\n".join([doc_texts[i] for i in I[0]])
128
+
129
# Answers with the Translation Option
def generate_answer(question, lang_choice):
    """Answer *question* from the indexed document, optionally translated.

    Args:
        question: The user's question.
        lang_choice: Output language — "English", "Hindi", "Marathi" or
            "Punjabi" (falls back to English codes for unknown values).

    Returns:
        The generated answer (translated when lang_choice != "English"),
        a prompt-to-upload message if no index exists yet, or an error
        string if generation/translation raises.
    """
    if index is None:
        return "Please upload and process a file first."

    context = get_context(question)
    # Using chat template to ensure proper formatting for TinyLlama
    messages = [
        {"role": "system", "content": "You are a helpful assistant. Answer strictly based on the context."},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    try:
        result = llm(prompt)
        # Extract the answer from the generated text, often it's after the last "Assistant" turn
        # This can be tricky with conversational models, you might need to adjust extraction logic
        # based on exact model output.
        generated_text = result[0]['generated_text']
        # A common way to get the response after the final user turn:
        # NOTE(review): assumes the chat template renders an "assistant\n"
        # marker before the reply — confirm against the model's template.
        answer = generated_text.split("assistant\n")[-1].strip()
        # For TinyLlama-1.1B-Chat-v1.0, it might be safer to parse the entire output or use `max_new_tokens` carefully
        # to ensure it doesn't repeat the prompt too much.

        if lang_choice != "English":
            # Model output is English; translate into the requested language.
            src_lang = "eng_Latn"
            tgt_lang = lang_map.get(lang_choice, "eng_Latn")
            translated = translator(answer, src_lang=src_lang, tgt_lang=tgt_lang)
            return translated[0]['translation_text']
        else:
            return answer

    except Exception as e:
        # Surface the failure to the UI rather than crashing the app.
        return f"Error generating answer: {str(e)}"
164
+
165
# --- Gradio UI ---
# Layout: left column uploads/processes a document; right column asks
# questions. Q&A widgets start disabled and are enabled by process_file().
with gr.Blocks(title="Multilingual RAG Chatbot with Quantization") as demo:
    gr.Markdown(
        """
        # Multilingual RAG Chatbot
        Upload your PDF or TXT file, then ask questions. The chatbot will retrieve relevant information
        and generate an answer, which can then be translated into your chosen language.
        """
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="1. Upload Document (PDF or TXT)", file_types=[".txt", ".pdf"])
            upload_status = gr.Textbox(label="Processing Status", interactive=False, placeholder="No file uploaded yet.")
            upload_button = gr.Button("Process Document")  # Explicit button to trigger processing
        with gr.Column():
            gr.Markdown("---")  # Visual separator
            gr.Markdown("### 2. Ask a Question and Get Answer")
            question_box = gr.Textbox(label="Your Question", placeholder="e.g., What is the main topic of the document?")
            lang_dropdown = gr.Dropdown(
                label="Output Language",
                choices=["English", "Hindi", "Marathi", "Punjabi"],
                value="English",
                interactive=False  # Initially disable until file is processed
            )
            generate_button = gr.Button("Generate Answer")  # Explicit button for generation
            answer_box = gr.Textbox(label="Answer", interactive=False, lines=5, placeholder="The answer will appear here...")

    # Event handling
    upload_button.click(
        fn=process_file,
        inputs=file_input,
        outputs=[upload_status, lang_dropdown, question_box, generate_button]  # Enable other components on success
    )
    generate_button.click(
        fn=generate_answer,
        inputs=[question_box, lang_dropdown],
        outputs=answer_box
    )
    # Also allow 'Enter' key for question box
    question_box.submit(
        fn=generate_answer,
        inputs=[question_box, lang_dropdown],
        outputs=answer_box
    )


demo.launch()