cryogenic22 commited on
Commit
d5ad9e6
·
verified ·
1 Parent(s): 5cbcca5

Create modules/qa_module.py

Browse files
Files changed (1) hide show
  1. modules/qa_module.py +51 -27
modules/qa_module.py CHANGED
@@ -1,35 +1,59 @@
 
 
1
# modules/qa_module.py

from typing import Dict, List

# Fixed: was `import chroma`, but the code below calls chromadb.Client(),
# which would raise NameError at construction time.
import chromadb
from langchain import OpenAI

from core.base_module import AIModule


class QAModule(AIModule):
    """Retrieval-augmented question answering.

    Retrieves the top-2 matching documents from a Chroma collection and asks
    an OpenAI LLM to answer the query against that context.
    """

    def __init__(self, model_name: str = "gpt-3.5-turbo"):
        # LLM used to synthesize the final answer from retrieved context.
        self.model = OpenAI(model_name=model_name)
        # In-memory Chroma client; the collection holds the indexed documents.
        self.vector_store = chromadb.Client()
        self.collection = self.vector_store.create_collection("qa_collection")

    async def process(self, input_data: Dict) -> Dict:
        """Answer ``input_data["query"]`` using retrieved documents as context.

        Returns ``{"answer": str, "sources": list}``.
        Raises ValueError if no query is supplied — previously a missing key
        silently sent ``None`` into the vector-store query.
        """
        query = input_data.get("query")
        if not query:
            raise ValueError("input_data must contain a non-empty 'query'")
        results = self.collection.query(
            query_texts=[query],
            n_results=2,
        )
        context = results['documents'][0]

        response = self.model.predict(
            f"Context: {context}\nQuestion: {query}\nAnswer:"
        )

        return {
            "answer": response,
            "sources": context,
        }

    async def get_status(self) -> Dict:
        # Fixed: Chroma collections report their size via .count();
        # they do not implement __len__, so len(self.collection) raised.
        return {"status": "operational", "documents_indexed": self.collection.count()}

    @property
    def capabilities(self) -> List[str]:
        return ["question-answering", "context-aware-responses"]
 
1
+
2
+
3
  # modules/qa_module.py
4
+ from transformers import pipeline
5
  from typing import Dict, List
6
+ import torch
 
 
7
 
8
class EnhancedQAModule:
    """Extractive question answering over retrieved context documents.

    Wraps a Hugging Face ``"question-answering"`` pipeline.

    NOTE(review): ``HuggingFaceH4/zephyr-7b-beta`` is a causal chat LM, not a
    model with an extractive-QA head — loading it into this pipeline task is
    likely to fail or behave poorly. Confirm the intended checkpoint (e.g. a
    SQuAD-finetuned model), or switch to a text-generation pipeline that
    actually consumes ``prompt_template``.
    """

    def __init__(
        self,
        model_name: str = "HuggingFaceH4/zephyr-7b-beta",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        # fp16 on GPU halves memory; fp32 on CPU avoids unsupported half ops.
        self.model = pipeline(
            "question-answering",
            model=model_name,
            device=device,
            model_kwargs={"torch_dtype": torch.float16 if device == "cuda" else torch.float32}
        )

        # Chat-style prompt kept for a future generative backend; the
        # extractive pipeline call in process() does not consume it.
        self.prompt_template = """
        <|system|>
        Answer the question based on the provided context. Be concise and specific.
        If the answer cannot be found in the context, say so.
        </s>
        <|user|>
        Context:
        {context}

        Question: {question}
        </s>
        <|assistant|>
        """

    async def process(self, query: str, context_docs: List[Dict]) -> Dict:
        """Answer ``query`` from ``context_docs``.

        Each doc is expected to look like
        ``{"content": str, "metadata": {"source": str}}`` — TODO confirm
        against the caller.

        Returns ``{"answer": str, "confidence": float, "sources": list[str]}``.
        """
        # Robustness: an empty context would make the QA pipeline raise.
        if not context_docs:
            return {"answer": "", "confidence": 0.0, "sources": []}

        # Tag each chunk with its source so answers remain attributable.
        context = "\n".join(
            f"[{doc['metadata']['source']}]: {doc['content']}"
            for doc in context_docs
        )

        # Fixed two defects here:
        #  1. the formatted chat prompt was built (`prompt = ...format(...)`)
        #     but never used anywhere — dead code, removed;
        #  2. the extractive QA pipeline was passed generation-only kwargs
        #     (max_length / num_beams / temperature) it does not accept —
        #     the documented answer-length knob is `max_answer_len`.
        response = self.model(
            question=query,
            context=context,
            max_answer_len=200
        )

        return {
            "answer": response["answer"],
            "confidence": response["score"],
            "sources": [doc["metadata"]["source"] for doc in context_docs]
        }