i-dhilip committed on
Commit
cbc1cbc
·
verified ·
1 Parent(s): 76df2dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -67
app.py CHANGED
@@ -5,11 +5,10 @@ import pandas as pd
5
  from datetime import datetime
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
  from langchain_community.llms import HuggingFacePipeline
8
- from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
9
  from langchain.chains import LLMChain
10
  from langchain.agents import Tool
11
  from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
12
- from langchain_community.utilities import TextRequestsWrapper
13
  from langchain_community.embeddings import HuggingFaceEmbeddings
14
  from langchain_community.vectorstores import Chroma
15
 
@@ -30,83 +29,90 @@ pipe = pipeline(
30
  )
31
  llm = HuggingFacePipeline(pipeline=pipe)
32
 
33
- # --- System Message ---
34
- system_prompt = """You are a helpful assistant tasked with answering questions using a set of tools.
35
- Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
36
- FINAL ANSWER: [YOUR FINAL ANSWER].
37
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
38
- system_message_prompt = SystemMessagePromptTemplate.from_template(system_prompt)
39
-
40
- # --- Tools ---
41
  ddg = DuckDuckGoSearchAPIWrapper()
42
- requests_wrapper = TextRequestsWrapper()
43
 
44
- def wiki_search(query):
45
- """Search Wikipedia for a query and return maximum 2 results."""
46
- search_results = ddg.run(query)
47
- return {"wiki_results": search_results}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- def web_search(query):
50
- """Search Tavily for a query and return maximum 3 results."""
51
- search_results = ddg.run(query)
52
- return {"web_results": search_results}
53
 
54
- def arxiv_search(query):
55
- """Search Arxiv for a query and return maximum 3 results."""
56
- url = f"https://export.arxiv.org/api/query?search_query=all:{query}&start=0&max_results=3"
57
- response = requests_wrapper.get(url)
58
- return {"arxiv_results": response.text}
 
 
 
 
 
 
 
 
 
59
 
60
  # --- Chroma DB Setup ---
61
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
62
-
63
  vector_store = Chroma(
64
  embedding_function=embeddings,
65
  persist_directory="./chroma_db"
66
  )
67
 
68
- def create_retriever_tool(query):
69
- """A tool to retrieve similar questions from a vector store."""
70
- similar_question = vector_store.similarity_search(query)
71
- return {"retriever_results": similar_question[0].page_content}
72
-
73
- tools = [
74
- Tool(
75
- name="Wikipedia Search",
76
- func=wiki_search,
77
- description="Search Wikipedia for a query and return maximum 2 results."
78
- ),
79
- Tool(
80
- name="Web Search",
81
- func=web_search,
82
- description="Search Tavily for a query and return maximum 3 results."
83
- ),
84
- Tool(
85
- name="Arxiv Search",
86
- func=arxiv_search,
87
- description="Search Arxiv for a query and return maximum 3 results."
88
- ),
89
- Tool(
90
- name="Retriever",
91
- func=create_retriever_tool,
92
- description="A tool to retrieve similar questions from a vector store."
93
- )
94
- ]
95
-
96
- def create_agent(llm, tools):
97
- """Create an agent with the specified tools."""
98
- prompt = ChatPromptTemplate.from_messages([
99
- system_message_prompt,
100
- HumanMessagePromptTemplate.from_template("{input}")
101
- ])
102
- llm_chain = LLMChain(llm=llm, prompt=prompt)
103
- return llm_chain
104
-
105
- def extract_final_answer(full_response):
106
- """Extract only the final answer from the agent's response."""
107
- if "FINAL ANSWER:" in full_response:
108
- return full_response.split("FINAL ANSWER:")[1].strip()
109
- return full_response.strip()
110
 
111
  def run_and_submit_all(profile: gr.OAuthProfile | None):
112
  """
 
5
  from datetime import datetime
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
  from langchain_community.llms import HuggingFacePipeline
8
+ from langchain.prompts import PromptTemplate
9
  from langchain.chains import LLMChain
10
  from langchain.agents import Tool
11
  from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
 
12
  from langchain_community.embeddings import HuggingFaceEmbeddings
13
  from langchain_community.vectorstores import Chroma
14
 
 
29
  )
30
  llm = HuggingFacePipeline(pipeline=pipe)
31
 
32
# --- Tools Setup ---
# Single DuckDuckGo search wrapper instance, shared by enhanced_search below
# for both the general web query and the Wikipedia-scoped query.
ddg = DuckDuckGoSearchAPIWrapper()
 
34
 
35
def enhanced_search(query):
    """Best-effort search combining general web and Wikipedia-scoped results.

    Args:
        query: Free-text search query.

    Returns:
        Dict with "web" (up to 3 snippets) and "wikipedia" (up to 2 snippets)
        lists of strings, or an empty dict if the search backend fails.
    """
    try:
        # General web search.
        web_results = ddg.results(query, 3)
        # Wikipedia "search": scoped by prefixing the site name to the query.
        wiki_results = ddg.results(f"wikipedia {query}", 2)
        return {
            # .get() guards against results missing a "snippet" key; a bare
            # r["snippet"] would raise KeyError and — via the except below —
            # discard BOTH result sets instead of just the bad entry.
            "web": [r.get("snippet", "") for r in web_results],
            "wikipedia": [r.get("snippet", "") for r in wiki_results],
        }
    except Exception as e:
        # Best-effort: a failed search must not crash the agent; callers
        # treat an empty dict as "no context available".
        print(f"Search error: {e}")
        return {}
49
+
50
# --- Prompt Engineering ---
# Instruction template for the LLM: answer from supplied context only, keep it
# short, and emit the answer after a fixed "FINAL ANSWER:" marker so it can be
# extracted mechanically downstream.
PROMPT_TEMPLATE = """Use the following context to answer the question.
If you don't know the answer, say you don't know. Keep answers very short.

Context:
{search_results}

Question: {question}

Think step by step, then write the final answer starting with FINAL ANSWER:"""

# from_template infers the input variables ({search_results}, {question})
# directly from the template string.
prompt = PromptTemplate.from_template(PROMPT_TEMPLATE)
65
+
66
+ # --- Answer Processing ---
67
+ def process_answer(raw_answer: str) -> str:
68
+ """Extract and clean the final answer"""
69
+ if "FINAL ANSWER:" in raw_answer:
70
+ answer = raw_answer.split("FINAL ANSWER:")[-1].strip()
71
+ answer = answer.split('\n')[0].strip()
72
+ answer = answer[:MAX_ANSWER_LENGTH]
73
+ return answer
74
+ return raw_answer.strip()[:MAX_ANSWER_LENGTH]
75
 
76
# --- Chroma DB Setup ---
# Sentence-transformer embeddings used to vectorize questions for similarity
# search (see get_agent_response, which retrieves similar past questions).
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vector_store = Chroma(
    embedding_function=embeddings,
    # Persist the index on local disk so it survives restarts.
    persist_directory="./chroma_db"
)
82
 
83
# --- Core Agent Logic ---
def get_agent_response(question: str) -> str:
    """Answer a question: search, assemble context, query the LLM, clean up.

    Pipeline: enhanced_search -> context formatting -> vector-store lookup of
    a similar prior question -> LLMChain generation -> process_answer. Any
    failure is reported as an error string rather than raised.
    """
    try:
        # Gather web/Wikipedia snippets (best-effort; may be empty).
        results = enhanced_search(question)

        # Build the context sections the prompt will receive.
        sections = []
        web_hits = results.get("web")
        if web_hits:
            sections.append("Web results:\n- " + "\n- ".join(web_hits))
        wiki_hits = results.get("wikipedia")
        if wiki_hits:
            sections.append("Wikipedia results:\n- " + "\n- ".join(wiki_hits))

        # Add the single most similar previously-seen question, if any.
        similar = vector_store.similarity_search(question, k=1)
        if similar:
            sections.append(f"Similar question: {similar[0].page_content}")

        if sections:
            full_context = "\n\n".join(sections)
        else:
            full_context = "No search results found"

        # Generate and post-process the answer.
        chain = LLMChain(llm=llm, prompt=prompt)
        raw_response = chain.run({
            "search_results": full_context,
            "question": question
        })
        return process_answer(raw_response)

    except Exception as e:
        # Surface the failure to the caller as text instead of crashing.
        print(f"Agent error: {e}")
        return f"Error processing question: {e}"
 
 
 
 
 
 
 
 
 
116
 
117
  def run_and_submit_all(profile: gr.OAuthProfile | None):
118
  """