Stanley03 commited on
Commit
36f67aa
·
verified ·
1 Parent(s): 6aebbe5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -6,11 +6,12 @@ import faiss
6
  import torch
7
  import numpy as np
8
 
9
- # Load docx content
10
  def load_docx_text(path):
11
  doc = Document(path)
12
  return "\n".join([p.text for p in doc.paragraphs if p.text.strip() != ""])
13
 
 
14
  text_data = load_docx_text("8_laws.docx")
15
 
16
  # Chunk text
@@ -20,25 +21,28 @@ def chunk_text(text, chunk_size=300, overlap=50):
20
 
21
  doc_chunks = chunk_text(text_data)
22
 
23
- # Embedding model and FAISS
24
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
25
  doc_embeddings = embedder.encode(doc_chunks)
 
 
26
  index = faiss.IndexFlatL2(doc_embeddings.shape[1])
27
  index.add(np.array(doc_embeddings))
28
 
29
- # LLM
30
  tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
31
- model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
32
 
33
- # RAG Logic
34
  def retrieve_context(query, k=3):
35
  query_vec = embedder.encode([query])
36
  _, indices = index.search(np.array(query_vec), k)
37
  return [doc_chunks[i] for i in indices[0]]
38
 
39
  def generate_answer(question):
40
- context = "\n".join(retrieve_context(question))
41
- prompt = f"""Use the context below to answer the question.
 
42
 
43
  Context:
44
  {context}
@@ -47,17 +51,24 @@ Question:
47
  {question}
48
 
49
  Answer:"""
50
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
51
- outputs = model.generate(**inputs, max_new_tokens=150)
52
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
53
 
54
- # Gradio App
 
 
 
 
 
 
 
 
 
 
55
  demo = gr.Interface(
56
  fn=generate_answer,
57
  inputs=gr.Textbox(lines=2, placeholder="Ask a question..."),
58
  outputs="text",
59
  title="📘 TinyLLaMA DOCX RAG",
60
- description="Ask questions from the 8 Laws docx file"
61
  )
62
 
63
  demo.launch()
 
6
  import torch
7
  import numpy as np
8
 
9
# Load .docx file
def load_docx_text(path):
    """Read the .docx at *path* and return its non-blank paragraphs joined by newlines.

    Uses python-docx's ``Document``; paragraphs that are empty or
    whitespace-only are dropped.
    """
    doc = Document(path)
    # Truthiness test replaces the noisy `strip() != ""` comparison;
    # a generator expression avoids building an intermediate list.
    return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
13
 
14
# NOTE: filename is hard-coded — it must match the .docx uploaded alongside app.py
text_data = load_docx_text("8_laws.docx")

# Chunk text
 
21
 
22
doc_chunks = chunk_text(text_data)

# Embed text: one vector per chunk via the MiniLM sentence-transformer.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
doc_embeddings = embedder.encode(doc_chunks)

# Build FAISS index — exact (brute-force) L2 search; dimensionality is
# taken from the embedding matrix itself.
index = faiss.IndexFlatL2(doc_embeddings.shape[1])
index.add(np.array(doc_embeddings))

# Load TinyLLaMA (CPU safe): no torch_dtype/device_map, so the model loads
# with default dtype on CPU — slower but avoids GPU/fp16 requirements.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
35
 
36
# RAG logic
def retrieve_context(query, k=3):
    """Return the k document chunks nearest to *query* in embedding space."""
    embedded_query = np.array(embedder.encode([query]))
    _, hit_ids = index.search(embedded_query, k)
    return [doc_chunks[chunk_id] for chunk_id in hit_ids[0]]
41
 
42
def generate_answer(question):
    """Answer *question* via RAG: retrieve chunks, build a prompt, run TinyLLaMA.

    Returns the decoded generation on success, or a human-readable error
    string on any failure (shown directly in the Gradio UI).
    """
    try:
        # Join the retrieved chunks into one context block for the prompt.
        context = "\n".join(retrieve_context(question))
        prompt = f"""Use the context below to answer the question.

Context:
{context}

Question:
{question}

Answer:"""

        # Debug aid: log the fully assembled prompt to the Space logs.
        print("🧠 Prompt:\n", prompt)

        inputs = tokenizer(prompt, return_tensors="pt")
        output = model.generate(**inputs, max_new_tokens=150)
        # NOTE(review): output[0] is the full sequence, so the decoded text
        # repeats the prompt before the answer — consider slicing off the
        # input length if only the completion should be shown.
        answer = tokenizer.decode(output[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        # Broad catch is deliberate: surface any failure in the UI
        # instead of crashing the app.
        print("❌ ERROR:", str(e))
        return f"An error occurred: {e}"
64
+
65
# Gradio interface: a single textbox in, plain text out, wired to the RAG pipeline.
demo = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question..."),
    outputs="text",
    title="📘 TinyLLaMA DOCX RAG",
    description="Ask a question about the 8 laws of health"
)

demo.launch()