Update buffalo_rag/model/rag.py
Browse files- buffalo_rag/model/rag.py +3 -21
buffalo_rag/model/rag.py
CHANGED
|
@@ -31,11 +31,9 @@ class BuffaloRAG:
|
|
| 31 |
query: str,
|
| 32 |
k: int = 5,
|
| 33 |
filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
| 34 |
-
"""Retrieve relevant chunks for a query."""
|
| 35 |
return self.vector_store.hybrid_search(query, k=k, filter_categories=filter_categories)
|
| 36 |
|
| 37 |
def format_context(self, results: List[Dict[str, Any]]) -> str:
|
| 38 |
-
"""Concatenate retrieved passages into context."""
|
| 39 |
ctx = []
|
| 40 |
for i, r in enumerate(results, start=1):
|
| 41 |
c = r["chunk"]
|
|
@@ -47,7 +45,6 @@ class BuffaloRAG:
|
|
| 47 |
return "\n".join(ctx)
|
| 48 |
|
| 49 |
def generate_response(self, query: str, context: str) -> str:
|
| 50 |
-
"""Generate response using the language model with error handling."""
|
| 51 |
prompt = f"""You are a friendly and professional counselor for international students at the University at Buffalo. Respond to the student's query in a supportive, detailed, and well-structured manner.
|
| 52 |
|
| 53 |
For your responses:
|
|
@@ -82,24 +79,20 @@ class BuffaloRAG:
|
|
| 82 |
return completion.choices[0].message.content
|
| 83 |
except Exception as e:
|
| 84 |
print(f"Error during generation: {str(e)}")
|
| 85 |
-
# Fallback response
|
| 86 |
return "I'm sorry, I encountered an issue generating a response. Please try asking your question in a different way or contact UB International Student Services directly for assistance."
|
| 87 |
|
| 88 |
def answer(self,
|
| 89 |
query: str,
|
| 90 |
k: int = 5,
|
| 91 |
filter_categories: Optional[List[str]] = None) -> Dict[str, Any]:
|
| 92 |
-
|
| 93 |
-
# Retrieve relevant chunks
|
| 94 |
results = self.retrieve(query, k=k, filter_categories=filter_categories)
|
| 95 |
|
| 96 |
-
|
| 97 |
context = self.format_context(results)
|
| 98 |
|
| 99 |
-
# Generate response
|
| 100 |
response = self.generate_response(query, context)
|
| 101 |
|
| 102 |
-
# Return response and sources
|
| 103 |
return {
|
| 104 |
'query': query,
|
| 105 |
'response': response,
|
|
@@ -111,15 +104,4 @@ class BuffaloRAG:
|
|
| 111 |
}
|
| 112 |
for result in results
|
| 113 |
]
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
# Example usage
|
| 117 |
-
if __name__ == "__main__":
|
| 118 |
-
rag = BuffaloRAG(model_name="1bitLLM/bitnet_b1_58-large")
|
| 119 |
-
response = rag.answer("How do I apply for OPT?")
|
| 120 |
-
|
| 121 |
-
print(f"Query: {response['query']}")
|
| 122 |
-
print(f"Response: {response['response']}")
|
| 123 |
-
print("\nSources:")
|
| 124 |
-
for source in response['sources']:
|
| 125 |
-
print(f"- {source['title']} (Score: {source['score']:.4f})")
|
|
|
|
| 31 |
query: str,
|
| 32 |
k: int = 5,
|
| 33 |
filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
|
|
|
| 34 |
return self.vector_store.hybrid_search(query, k=k, filter_categories=filter_categories)
|
| 35 |
|
| 36 |
def format_context(self, results: List[Dict[str, Any]]) -> str:
|
|
|
|
| 37 |
ctx = []
|
| 38 |
for i, r in enumerate(results, start=1):
|
| 39 |
c = r["chunk"]
|
|
|
|
| 45 |
return "\n".join(ctx)
|
| 46 |
|
| 47 |
def generate_response(self, query: str, context: str) -> str:
|
|
|
|
| 48 |
prompt = f"""You are a friendly and professional counselor for international students at the University at Buffalo. Respond to the student's query in a supportive, detailed, and well-structured manner.
|
| 49 |
|
| 50 |
For your responses:
|
|
|
|
| 79 |
return completion.choices[0].message.content
|
| 80 |
except Exception as e:
|
| 81 |
print(f"Error during generation: {str(e)}")
|
|
|
|
| 82 |
return "I'm sorry, I encountered an issue generating a response. Please try asking your question in a different way or contact UB International Student Services directly for assistance."
|
| 83 |
|
| 84 |
def answer(self,
|
| 85 |
query: str,
|
| 86 |
k: int = 5,
|
| 87 |
filter_categories: Optional[List[str]] = None) -> Dict[str, Any]:
|
| 88 |
+
|
|
|
|
| 89 |
results = self.retrieve(query, k=k, filter_categories=filter_categories)
|
| 90 |
|
| 91 |
+
|
| 92 |
context = self.format_context(results)
|
| 93 |
|
|
|
|
| 94 |
response = self.generate_response(query, context)
|
| 95 |
|
|
|
|
| 96 |
return {
|
| 97 |
'query': query,
|
| 98 |
'response': response,
|
|
|
|
| 104 |
}
|
| 105 |
for result in results
|
| 106 |
]
|
| 107 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|