abrar-adnan committed (verified)
Commit 6445da3 · 1 Parent(s): 97eabab

Update agents/research_agent.py

Files changed (1)
  agents/research_agent.py  +82 -84
agents/research_agent.py CHANGED
@@ -1,85 +1,83 @@
- from langchain_ollama import OllamaLLM
- from langchain_ollama import ChatOllama
- from typing import Dict, List
- from langchain_core.documents.base import Document
- from config.settings import settings
-
- class ResearchAgent:
-     def __init__(self):
-         """
-         Initialize the research agent with local Ollama LLM.
-         """
-         print("Initializing ResearchAgent with Ollama (local)...")
-         self.llm = ChatOllama(
-             base_url=settings.OLLAMA_BASE_URL,
-             model=settings.OLLAMA_MODEL_RESEARCH,
-             temperature=0.3,
-             num_predict=300,  # max_tokens equivalent
-         )
-         print("Ollama LLM initialized successfully.")
-
-     def sanitize_response(self, response_text: str) -> str:
-         """
-         Sanitize the LLM's response by stripping unnecessary whitespace.
-         """
-         return response_text.strip()
-
-     def generate_prompt(self, question: str, context: str) -> str:
-         """
-         Generate a structured prompt for the LLM to generate a precise and factual answer.
-         """
-         prompt = f"""
-         You are an AI assistant designed to provide precise and factual answers based on the given context.
-
-         **Instructions:**
-         - Answer the following question using only the provided context.
-         - Be clear, concise, and factual.
-         - Return as much information as you can get from the context.
-
-         **Question:** {question}
-         **Context:**
-         {context}
-
-         **Provide your answer below:**
-         """
-         return prompt
-
-     def generate(self, question: str, documents: List[Document]) -> Dict:
-         """
-         Generate an initial answer using the provided documents.
-         """
-         print(f"ResearchAgent.generate called with question='{question}' and {len(documents)} documents.")
-
-         # Combine the top document contents into one string
-         context = "\n\n".join([doc.page_content for doc in documents])
-         print(f"Combined context length: {len(context)} characters.")
-
-         # Create a prompt for the LLM
-         prompt = self.generate_prompt(question, context)
-         print("Prompt created for the LLM.")
-
-         # Call the LLM to generate the answer
-         try:
-             print("Sending prompt to Ollama...")
-             response = self.llm.invoke(prompt)
-             print("LLM response received.")
-
-             # Extract content from LangChain message
-             if hasattr(response, 'content'):
-                 llm_response = response.content
-             else:
-                 llm_response = str(response)
-
-         except Exception as e:
-             print(f"Error during model inference: {e}")
-             raise RuntimeError("Failed to generate answer due to a model error.") from e
-
-         # Sanitize the response
-         draft_answer = self.sanitize_response(llm_response) if llm_response else "I cannot answer this question based on the provided documents."
-
-         print(f"Generated answer: {draft_answer}")
-
-         return {
-             "draft_answer": draft_answer,
-             "context_used": context
-         }
 
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from typing import Dict, List
+ from langchain_core.documents.base import Document
+ from config.settings import settings
+ import torch
+
+ class ResearchAgent:
+     def __init__(self):
+         """
+         Initialize the research agent with a local Hugging Face Transformers model.
+         """
+         print("Initializing ResearchAgent with Hugging Face Transformers...")
+         model_name = getattr(settings, "HF_MODEL_RESEARCH", "google/flan-t5-large")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+         # Run on GPU when available, otherwise fall back to CPU
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+
+         print(f"Model '{model_name}' initialized successfully on {self.device}.")
+
+     def sanitize_response(self, response_text: str) -> str:
+         """
+         Sanitize the LLM's response by stripping unnecessary whitespace.
+         """
+         return response_text.strip()
+
+     def generate_prompt(self, question: str, context: str) -> str:
+         """
+         Generate a structured prompt for the LLM to generate a precise and factual answer.
+         """
+         prompt = f"""
+         You are an AI assistant designed to provide precise and factual answers based on the given context.
+
+         **Instructions:**
+         - Answer the following question using only the provided context.
+         - Be clear, concise, and factual.
+         - Return as much information as you can get from the context.
+
+         **Question:** {question}
+         **Context:**
+         {context}
+
+         **Provide your answer below:**
+         """
+         return prompt
+
+     def generate(self, question: str, documents: List[Document]) -> Dict:
+         """
+         Generate an initial answer using the provided documents.
+         """
+         print(f"ResearchAgent.generate called with question='{question}' and {len(documents)} documents.")
+
+         # Combine the top document contents into one string
+         context = "\n\n".join([doc.page_content for doc in documents])
+         print(f"Combined context length: {len(context)} characters.")
+
+         # Create a prompt for the LLM
+         prompt = self.generate_prompt(question, context)
+         print("Prompt created for the LLM.")
+
+         # Run inference to generate the answer
+         try:
+             print("Running inference with Transformers...")
+             inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
+             # do_sample=True is needed for temperature to take effect;
+             # without it, generate() decodes greedily and ignores temperature
+             outputs = self.model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
+             llm_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             print("Model response received.")
+         except Exception as e:
+             print(f"Error during model inference: {e}")
+             raise RuntimeError("Failed to generate answer due to a model error.") from e
+
+         # Sanitize the response
+         draft_answer = self.sanitize_response(llm_response) if llm_response else "I cannot answer this question based on the provided documents."
+
+         print(f"Generated answer: {draft_answer}")
+
+         return {
+             "draft_answer": draft_answer,
+             "context_used": context
+         }
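
For reference, a minimal usage sketch of the updated class. This is an illustration, not code from the repo: the sample Document contents are invented, and it assumes config/settings.py exposes a settings object (optionally defining HF_MODEL_RESEARCH, per the getattr fallback above).

# Hypothetical usage sketch; the Document contents below are made up.
from langchain_core.documents.base import Document
from agents.research_agent import ResearchAgent

agent = ResearchAgent()  # downloads google/flan-t5-large on first run unless HF_MODEL_RESEARCH overrides it

docs = [
    Document(page_content="FLAN-T5 is an instruction-tuned seq2seq model."),
    Document(page_content="It is distributed through Hugging Face Transformers."),
]

result = agent.generate("What is FLAN-T5?", docs)
print(result["draft_answer"])

Note that the tokenizer call truncates the prompt at 1,024 tokens, so a long combined context is cut off from the end; the question itself survives truncation because it appears before the context in the prompt.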