abdullahtahir committed
Commit 9ca09c4 · verified · 1 Parent(s): 2f1ff1e

Update app.py

Files changed (1):
  1. app.py +63 -39
app.py CHANGED

@@ -10,7 +10,9 @@ from langchain.prompts import PromptTemplate
 from transformers import AutoModelForSeq2SeqLM, pipeline, AutoTokenizer
 import torch
 
-# Check if about_me.txt exists, create a sample if not
+# -------------------------------
+# 1. Ensure about_me.txt exists
+# -------------------------------
 if not os.path.exists("about_me.txt"):
     with open("about_me.txt", "w") as f:
         f.write("""
@@ -18,37 +20,44 @@ if not os.path.exists("about_me.txt"):
 This is a sample portfolio text. Please replace this with your actual portfolio content.
 """)
 
-# Load data
+# -------------------------------
+# 2. Load data
+# -------------------------------
 try:
     loader = TextLoader("about_me.txt")
     docs = loader.load()
 except Exception as e:
     print(f"Error loading document: {e}")
-    # Create fallback document
     from langchain.schema import Document
     docs = [Document(page_content="Hello! I am a portfolio chatbot ready to help you.")]
 
-# Split documents
+# -------------------------------
+# 3. Split documents into chunks
+# -------------------------------
 text_splitter = RecursiveCharacterTextSplitter(
-    chunk_size=200,  # Reduced chunk size
-    chunk_overlap=30
+    chunk_size=500,  # Larger chunk size for better context
+    chunk_overlap=50
 )
 split_docs = text_splitter.split_documents(docs)
 
-# Initialize embeddings and vector store
+# -------------------------------
+# 4. Create embeddings & FAISS DB
+# -------------------------------
 print("Loading embeddings...")
 embedding_model = HuggingFaceEmbeddings(
     model_name="sentence-transformers/all-MiniLM-L6-v2",
-    model_kwargs={'device': 'cpu'}  # Force CPU usage
+    model_kwargs={'device': 'cpu'}
 )
 
 print("Creating vector database...")
 db = FAISS.from_documents(split_docs, embedding_model)
 
-# Load smaller model with better error handling
+# -------------------------------
+# 5. Load language model
+# -------------------------------
 print("Loading language model...")
-model_id = "google/flan-t5-small"  # Changed to smaller model
-device = "cpu"  # Force CPU usage for stability
+model_id = "google/flan-t5-base"  # More capable than small
+device = "cpu"
 
 try:
     tokenizer = AutoTokenizer.from_pretrained(model_id)
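
Note on the chunking change above: flan-t5 checkpoints are typically served with a 512-token encoder window, and the generation pipeline in this file truncates silently (`truncation=True`). With `chunk_size=500` and the `k=3` retriever configured further down, the stuffed context can approach that window. A rough way to check, as a standalone sketch (the tokenizer call mirrors the one in app.py; the sample text and the token estimate are illustrative assumptions, not part of the commit):

# Illustrative check, not part of this commit: will k=3 chunks of ~500
# characters plus the prompt fit the model's input window?
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
sample_context = "portfolio text " * 100    # ~1500 chars, stands in for 3 chunks
n_tokens = len(tok(sample_context)["input_ids"])
print(f"context tokens: {n_tokens}, window: {tok.model_max_length}")
# If context tokens plus the prompt exceed tok.model_max_length (512 here),
# truncation=True drops the tail of the stuffed context.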
@@ -57,77 +66,92 @@ try:
         torch_dtype=torch.float32,
         device_map="auto" if torch.cuda.is_available() else None
     )
-
+
     pipe = pipeline(
         "text2text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=128,  # Only use max_new_tokens to avoid warning
+        max_new_tokens=128,
         truncation=True,
         device=0 if torch.cuda.is_available() else -1
     )
-
+
     llm = HuggingFacePipeline(pipeline=pipe)
     print("Model loaded successfully!")
-
+
 except Exception as e:
     print(f"Error loading model: {e}")
-    # Fallback to a simpler setup
-    from langchain_community.llms import HuggingFacePipeline
     pipe = pipeline(
         "text-generation",
         model="microsoft/DialoGPT-medium",
         max_length=200,
-        device=-1  # CPU only
+        device=-1
     )
     llm = HuggingFacePipeline(pipeline=pipe)
 
-# Create RetrievalQA chain
+# -------------------------------
+# 6. Custom Prompt
+# -------------------------------
 custom_prompt = PromptTemplate(
-    template="Based on the following context, answer the question concisely:\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:",
+    template=(
+        "Answer the question using only the provided context. "
+        "If the answer is not in the context, say you don't know.\n\n"
+        "Question: {question}\n\n"
+        "Context: {context}\n\n"
+        "Answer:"
+    ),
     input_variables=["context", "question"]
 )
 
+# -------------------------------
+# 7. Create RetrievalQA chain
+# -------------------------------
 qa_chain = RetrievalQA.from_chain_type(
     llm=llm,
     chain_type="stuff",
-    retriever=db.as_retriever(search_kwargs={"k": 1}),  # Reduced to 1 document
+    retriever=db.as_retriever(search_kwargs={"k": 3}),  # Fetch more context
     chain_type_kwargs={"prompt": custom_prompt},
     return_source_documents=False
 )
 
+# -------------------------------
+# 8. Ask function with debug logs
+# -------------------------------
 def ask_bot_alternative(question):
-    """Enhanced chatbot function with better error handling"""
     try:
-        if not question or question.strip() == "":
+        if not question.strip():
             return "Please ask me a question about the portfolio!"
-
-        # Limit input length
+
         question = question[:500]
-
-        print(f"Processing question: {question}")
+        print(f"\nProcessing question: {question}")
+
+        # Retrieve and log context
+        retriever = db.as_retriever(search_kwargs={"k": 3})
+        context_docs = retriever.get_relevant_documents(question)
+        print("\n--- Retrieved Context ---")
+        for i, d in enumerate(context_docs, 1):
+            print(f"[Doc {i}] {d.page_content[:200]}...\n")
+
+        # Get answer from chain
         response = qa_chain.invoke({"query": question})
-
-        # Extract answer with multiple fallbacks
         if isinstance(response, dict):
             answer = response.get("result") or response.get("answer") or str(response)
         else:
             answer = str(response)
-
-        # Clean and limit output
+
         answer = answer.strip()
         if len(answer) > 1000:
             answer = answer[:1000] + "..."
-
-        return answer if answer else "I couldn't generate a response. Please try rephrasing your question."
-
+
+        return answer or "I couldn't find an answer in the portfolio content."
+
     except Exception as e:
         print(f"Error in ask_bot_alternative: {e}")
-        return f"Sorry, I encountered an error: {str(e)[:200]}. Please try again with a different question."
-
-# Create Gradio interface with better configuration
-print("Starting Gradio interface...")
+        return f"Sorry, I encountered an error: {str(e)[:200]}"
 
+# -------------------------------
+# 9. Gradio Interface
+# -------------------------------
 iface = gr.Interface(
     fn=ask_bot_alternative,
     inputs=gr.Textbox(
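
One compatibility caveat on the debug-logging block added above: newer LangChain releases deprecate `retriever.get_relevant_documents()` in favor of the runnable `invoke()` interface, so depending on the pinned version the retrieval log may emit deprecation warnings. A version-tolerant variant, as a sketch (assumes the `db` object from app.py; the sample question is illustrative):

# Sketch, assuming `db` from app.py: fetch context via the newer runnable
# API when available, falling back to the older method.
question = "What is this portfolio about?"   # illustrative input
retriever = db.as_retriever(search_kwargs={"k": 3})
try:
    context_docs = retriever.invoke(question)                    # LangChain >= 0.1
except AttributeError:
    context_docs = retriever.get_relevant_documents(question)   # older releases
for i, d in enumerate(context_docs, 1):
    print(f"[Doc {i}] {d.page_content[:200]}")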
@@ -156,4 +180,4 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         server_port=7860,
         show_error=True
-    )
+    )
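
To exercise the new guards without opening the Gradio UI, a quick smoke test can import the module directly. A sketch, assuming app.py sits in the working directory and the model downloads succeed (importing still runs all module-level setup, since only `launch()` sits behind the `__main__` guard):

# Smoke test for the updated ask function (illustrative). The first call is
# slow: importing app builds the FAISS index and loads the model.
from app import ask_bot_alternative

print(ask_bot_alternative(""))                                # empty-input guard
print(ask_bot_alternative("What is this portfolio about?"))   # normal path
print(ask_bot_alternative("detail " * 200))                   # clipped to 500 chars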