Spaces:
Sleeping
Sleeping
| # modules/qa_module.py | |
| from transformers import pipeline | |
| from typing import Dict, List | |
| import torch | |
class EnhancedQAModule:
    """Answer a question from retrieved documents using a Hugging Face pipeline.

    The module wraps a ``question-answering`` pipeline. Context documents are
    concatenated into one context string, each prefixed with its source tag,
    and the pipeline's ``answer`` / ``score`` are returned together with the
    list of sources.
    """

    # Returned as the answer when there is no context to search.
    _NO_CONTEXT_ANSWER = "The answer cannot be found: no context was provided."

    def __init__(
        self,
        model_name: str = "HuggingFaceH4/zephyr-7b-beta",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Load the pipeline.

        Args:
            model_name: Checkpoint name or path handed to ``pipeline``.
            device: Device string handed to ``pipeline``; half precision is
                requested when it equals ``"cuda"``, full precision otherwise.
        """
        # NOTE(review): the default checkpoint looks like a chat/generative
        # model while the task is extractive "question-answering" -- confirm
        # the checkpoint has a usable QA head, or pass a QA model explicitly.
        self.model = pipeline(
            "question-answering",
            model=model_name,
            device=device,
            model_kwargs={"torch_dtype": torch.float16 if device == "cuda" else torch.float32}
        )
        # Chat-style prompt kept as a public attribute. process() does not use
        # it because the pipeline is called with question/context directly.
        self.prompt_template = """
<|system|>
Answer the question based on the provided context. Be concise and specific.
If the answer cannot be found in the context, say so.
</s>
<|user|>
Context:
{context}
Question: {question}
</s>
<|assistant|>
"""

    async def process(self, query: str, context_docs: List[Dict]) -> Dict:
        """Answer ``query`` using ``context_docs`` as the only evidence.

        Args:
            query: The question to answer.
            context_docs: Documents to search. Each one is read as
                ``doc["content"]`` and ``doc["metadata"]["source"]``.

        Returns:
            A dict with ``"answer"`` and ``"confidence"`` (the pipeline's
            ``answer`` and ``score``) and ``"sources"`` (one entry per
            document, in input order). When ``context_docs`` is empty the
            pipeline is not called; a fixed "cannot be found" answer with
            confidence ``0.0`` and no sources is returned instead.

        Raises:
            KeyError: If a document lacks ``content`` or ``metadata.source``.
        """
        # Collect the sources once; they label the context and are returned.
        sources = [doc["metadata"]["source"] for doc in context_docs]

        if not context_docs:
            # Nothing to search: answer explicitly rather than handing the
            # pipeline an empty context string.
            return {
                "answer": self._NO_CONTEXT_ANSWER,
                "confidence": 0.0,
                "sources": sources,
            }

        context = "\n".join(
            f"[{source}]: {doc['content']}"
            for source, doc in zip(sources, context_docs)
        )

        # Only question/context are passed. Text-generation options such as
        # max_length / num_beams / temperature are not extractive-QA options.
        # NOTE(review): this call is synchronous, so it runs on the event-loop
        # thread for the duration of inference.
        response = self.model(question=query, context=context)

        return {
            "answer": response["answer"],
            "confidence": response["score"],
            "sources": sources,
        }