jake2004 commited on
Commit
01a6a25
Β·
verified Β·
1 Parent(s): c8f0784

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
+
5
+ # Load RAG model
6
+ model_name = "facebook/rag-sequence-nq"
7
+ tokenizer = RagTokenizer.from_pretrained(model_name)
8
+ retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
9
+ model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
10
+
11
+ # Function to process uploaded document
12
+ def process_file(file):
13
+ if file is None:
14
+ return "Please upload a document."
15
+
16
+ file_text = file.decode("utf-8")
17
+ return file_text
18
+
19
+ # Function to answer questions using RAG
20
+ def answer_question(document, question):
21
+ if not document.strip():
22
+ return "Please provide document content."
23
+
24
+ inputs = tokenizer(question, document, return_tensors="pt", truncation=True)
25
+ with torch.no_grad():
26
+ generated = model.generate(**inputs)
27
+ answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
28
+
29
+ return answer
30
+
31
+ # Gradio UI
32
+ with gr.Blocks() as app:
33
+ gr.Markdown("# πŸ“„ Advanced RAG NLP Document Editor")
34
+
35
+ # File Uploader
36
+ file_input = gr.File(label="Upload Document (TXT only)", type="binary")
37
+ file_output = gr.Textbox(label="Extracted Text", lines=10)
38
+
39
+ file_input.change(process_file, inputs=file_input, outputs=file_output)
40
+
41
+ # Question Answering
42
+ question_input = gr.Textbox(label="Ask a Question")
43
+ answer_output = gr.Textbox(label="Answer", lines=2)
44
+
45
+ submit_btn = gr.Button("Get Answer")
46
+ submit_btn.click(answer_question, inputs=[file_output, question_input], outputs=answer_output)
47
+
48
+ # Launch in Hugging Face Spaces
49
+ app.launch()