# Synaptyx / modules/qa_module.py
# (Hugging Face upload by cryogenic22 — commit d5ad9e6, "Create modules/qa_module.py")
# modules/qa_module.py
from transformers import pipeline
from typing import Dict, List
import torch
class EnhancedQAModule:
    """Extractive question answering over retrieved context documents.

    Wraps a Hugging Face ``question-answering`` pipeline: each retrieved
    document is tagged with its source, concatenated into a single context,
    and the pipeline extracts an answer span plus a confidence score.
    """

    def __init__(
        self,
        model_name: str = "HuggingFaceH4/zephyr-7b-beta",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Load the QA pipeline.

        Args:
            model_name: Hugging Face model id to load under the
                ``question-answering`` task.
            device: ``"cuda"`` or ``"cpu"``; chosen at import time from
                ``torch.cuda.is_available()``.
        """
        # NOTE(review): the default checkpoint is a chat/causal LM, not an
        # extractive-QA architecture, so loading it under the
        # "question-answering" task is expected to fail. An extractive model
        # (e.g. "deepset/roberta-base-squad2") should be passed instead —
        # confirm with callers before changing the default.
        self.model = pipeline(
            "question-answering",
            model=model_name,
            device=device,
            # fp16 halves GPU memory; CPU kernels generally require fp32.
            model_kwargs={
                "torch_dtype": torch.float16 if device == "cuda" else torch.float32
            },
        )
        # Kept for backward compatibility: external code may read or override
        # this attribute. The extractive pipeline itself does not consume a
        # chat-style prompt, so it is not used by process().
        self.prompt_template = """
<|system|>
Answer the question based on the provided context. Be concise and specific.
If the answer cannot be found in the context, say so.
</s>
<|user|>
Context:
{context}
Question: {question}
</s>
<|assistant|>
"""

    async def process(self, query: str, context_docs: List[Dict]) -> Dict:
        """Answer *query* against the supplied retrieved documents.

        Args:
            query: The user question.
            context_docs: Retrieved documents; each must carry a ``content``
                string and a ``metadata`` dict containing a ``source`` key.

        Returns:
            Dict with ``answer`` (str), ``confidence`` (float pipeline score),
            and ``sources`` (list of the documents' source identifiers).
        """
        # Robustness fix: an empty retrieval result previously reached the
        # pipeline with an empty context and raised; report it explicitly.
        if not context_docs:
            return {
                "answer": "No context documents were provided.",
                "confidence": 0.0,
                "sources": [],
            }

        # Tag each chunk with its source so answers stay attributable.
        context = "\n".join(
            f"[{doc['metadata']['source']}]: {doc['content']}"
            for doc in context_docs
        )

        # Bug fix: the original formatted self.prompt_template into a local
        # variable that was never used (dead code, removed), and passed
        # generation-only kwargs (max_length / num_beams / temperature) that
        # the extractive QA pipeline does not accept. max_answer_len is the
        # pipeline's supported answer-length control.
        response = self.model(
            question=query,
            context=context,
            max_answer_len=200,
        )
        return {
            "answer": response["answer"],
            "confidence": response["score"],
            "sources": [doc["metadata"]["source"] for doc in context_docs],
        }