jk12p committed on
Commit
5add394
·
verified ·
1 Parent(s): 6297f96

Create app.py

Files changed (1)
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
+ import streamlit as st
+ import torch
+ import fitz  # PyMuPDF
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ import numpy as np
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # --- CONFIG ---
+ HF_TOKEN = "your_huggingface_token_here"  # Add your Hugging Face token (Gemma is a gated model)
+
+ # Load tokenizer and model with optimizations
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
+ model = AutoModelForCausalLM.from_pretrained(
+     "google/gemma-2b-it", token=HF_TOKEN,
+     torch_dtype=torch.float16,  # Use half-precision weights to reduce memory
+     device_map="auto"           # Place the model on the best available device (CPU/GPU)
+ )
+
+ # Load sentence-transformer model for embedding generation
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
+
+ # --- UI ---
+ st.title("🔍 RAG App using 🤖 Gemma 2B")
+
+ uploaded_file = st.file_uploader("📄 Upload a PDF or TXT file", type=["pdf", "txt"])
+
+ # Extract text from the uploaded file (PDF/TXT)
+ def extract_text(file):
+     text = ""
+     if file.type == "application/pdf":
+         doc = fitz.open(stream=file.read(), filetype="pdf")
+         for page in doc:
+             text += page.get_text()
+     elif file.type == "text/plain":
+         text = file.read().decode("utf-8")
+     return text
+
+ # Split text into fixed-size chunks for indexing
+ def split_into_chunks(text, chunk_size=500):
+     return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
+
+ # Create a FAISS index over the chunk embeddings for fast retrieval
+ def create_faiss_index(chunks):
+     embeddings = embedder.encode(chunks)
+     dim = embeddings.shape[1]
+     index = faiss.IndexFlatL2(dim)
+     index.add(np.array(embeddings))
+     return index, embeddings
+
+ # Retrieve the top-k chunks most relevant to the query
+ def retrieve_chunks(query, chunks, index, embeddings, k=3):
+     query_embedding = embedder.encode([query])
+     D, I = index.search(np.array(query_embedding), k)
+     return [chunks[i] for i in I[0]]
+
+ # --- MAIN LOGIC ---
+ if uploaded_file:
+     st.success("✅ File uploaded successfully!")
+     raw_text = extract_text(uploaded_file)
+     chunks = split_into_chunks(raw_text)
+
+     st.info(f"📚 Document split into {len(chunks)} chunks")
+
+     index, embeddings = create_faiss_index(chunks)
+
+     user_question = st.text_input("💬 Ask something about the document:")
+
+     if user_question:
+         with st.spinner("Thinking..."):
+             context = "\n".join(retrieve_chunks(user_question, chunks, index, embeddings))
+
+             # Generate a response from Gemma 2B conditioned on the retrieved context
+             input_ids = tokenizer.encode(f"Answer the question based on the context below:\n\nContext:\n{context}\n\nQuestion: {user_question}\nAnswer:", return_tensors="pt").to(model.device)
+
+             with torch.no_grad():  # Disable gradient computation for inference
+                 outputs = model.generate(input_ids, max_new_tokens=256, do_sample=True, temperature=0.7)
+
+             generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             answer = generated_text.split("Answer:")[-1].strip()
+
+             st.markdown("### 🧠 Answer:")
+             st.success(answer)
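
Note (not part of the commit): Streamlit re-executes the whole script on every user interaction, so as written the Gemma weights and the sentence-transformer are reloaded on each rerun. A minimal sketch of how the heavy loads could be cached with st.cache_resource is below; it assumes a recent Streamlit (>= 1.18, where st.cache_resource exists) and reuses the same model IDs as app.py, and the helper names load_generator / load_embedder are hypothetical.

import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

@st.cache_resource  # loaded once per process; later reruns reuse the same objects
def load_generator(hf_token: str):
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b-it", token=hf_token,
        torch_dtype=torch.float16, device_map="auto"
    )
    return tokenizer, model

@st.cache_resource
def load_embedder():
    return SentenceTransformer("all-MiniLM-L6-v2")

# Possible usage near the top of app.py:
#   tokenizer, model = load_generator(HF_TOKEN)
#   embedder = load_embedder()

st.cache_resource (rather than st.cache_data) is the appropriate decorator here because model and tokenizer objects are not serializable and should be shared as-is across reruns and sessions.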