ayush2917 committed on
Commit
e6c70b0
·
verified ·
1 Parent(s): a19857d

Update src/generation.py

Browse files
Files changed (1) hide show
  1. src/generation.py +26 -28
src/generation.py CHANGED
@@ -1,31 +1,29 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
2
 
3
class ResponseGenerator:
    """Generate chatbot replies for Rupeia with a causal language model.

    Wraps a Hugging Face tokenizer/model pair and a fixed prompt template
    that injects retrieved context and the recent chat history.
    """

    def __init__(self, model_name="distilgpt2"):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face model identifier (default: distilgpt2).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
        # distilgpt2 ships without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.prompt_template = """
You are a customer support chatbot for Rupeia, a financial platform. Provide accurate, concise answers about Investments, Goals, Benefits, Gadgets, and News & Blogs. Use the context and history to respond naturally. If unsure, say: "I’m not sure about that. Could you clarify or ask about Rupeia features?"

Context: {context}
History: {history}
User: {user_input}
Answer: """

    def generate_response(self, user_input, context, history):
        """Build the prompt, run generation, and return the new text only.

        Args:
            user_input: The user's current message.
            context: Retrieved document text to ground the answer.
            history: List of (user, bot) message pairs; only the last 3 are used.

        Returns:
            The generated answer, or a fallback (context / canned apology)
            when the model produces nothing.
        """
        history_str = "\n".join(f"User: {u}\nBot: {b}" for u, b in history[-3:])
        prompt = self.prompt_template.format(context=context, history=history_str, user_input=user_input)
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
        # BUG FIX: the original passed max_length=200 to generate(), which
        # bounds prompt + completion together — a prompt longer than 200
        # tokens left no room to generate anything. max_new_tokens bounds
        # only the generated continuation. Attention mask is passed
        # explicitly so padding is handled correctly.
        outputs = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=200,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Decode only the tokens generated after the prompt.
        response = self.tokenizer.decode(outputs[:, inputs["input_ids"].shape[-1]:][0], skip_special_tokens=True).strip()
        return response if response else context or "I’m not sure about that. Could you clarify or ask about Rupeia features?"
 
1
+ import numpy as np
2
+ from sentence_transformers import SentenceTransformer
3
+ import json
4
+ import os
5
+ from typing import List, Dict
6
 
7
class DocumentRetriever:
    """Embed the Rupeia document corpus and retrieve the passages most
    similar to a query via cosine similarity over sentence embeddings."""

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """Load the embedding model, the corpus, and precompute one
        embedding per document (done once, reused by every retrieve()).

        Args:
            model_name: SentenceTransformer model identifier.
        """
        self.model = SentenceTransformer(model_name)
        self.documents = self._load_documents()
        self.doc_embeddings = self._embed_documents()

    def _load_documents(self) -> List[Dict]:
        """Load the corpus; each entry is expected to carry a 'content' key."""
        # Explicit encoding so the corpus parses identically on every OS.
        with open('data/rupeia_document.json', 'r', encoding='utf-8') as f:
            return json.load(f)

    def _embed_documents(self) -> np.ndarray:
        """Return an (n_docs, dim) array of document embeddings."""
        texts = [doc['content'] for doc in self.documents]
        return self.model.encode(texts)

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return the top_k documents most similar to ``query``.

        BUG FIX: the original ranked by raw dot product, which biases the
        ranking toward documents whose embeddings have larger norms.
        Cosine similarity (dot product over the norms) is the intended
        metric for all-MiniLM-L6-v2 embeddings.

        Args:
            query: Free-text query string.
            top_k: Number of documents to return, best first.
        """
        query_embedding = self.model.encode(query)
        doc_norms = np.linalg.norm(self.doc_embeddings, axis=1)
        denom = doc_norms * np.linalg.norm(query_embedding)
        denom[denom == 0] = 1e-12  # guard against all-zero vectors
        scores = np.dot(self.doc_embeddings, query_embedding) / denom
        # argsort is ascending: take the last top_k indices, reverse to descending.
        top_indices = np.argsort(scores)[-top_k:][::-1]
        return [self.documents[i] for i in top_indices]
26
+
27
def retrieve_relevant_documents(query: str, top_k: int = 3) -> List[Dict]:
    """Module-level convenience wrapper around DocumentRetriever.

    BUG FIX: the original constructed a brand-new DocumentRetriever on
    every call, reloading the SentenceTransformer model and re-embedding
    the entire corpus each time. A single instance is now created lazily
    on first use and cached on the function object.

    Args:
        query: Free-text user query.
        top_k: Number of documents to return (defaults to 3, matching the
            previous fixed behavior — backward compatible).
    """
    if not hasattr(retrieve_relevant_documents, "_retriever"):
        retrieve_relevant_documents._retriever = DocumentRetriever()
    return retrieve_relevant_documents._retriever.retrieve(query, top_k)