ayush2917 committed on
Commit
486c63e
·
verified ·
1 Parent(s): e6c70b0

Update src/generation.py

Browse files
Files changed (1) hide show
  1. src/generation.py +21 -26
src/generation.py CHANGED
@@ -1,29 +1,24 @@
1
- import numpy as np
2
- from sentence_transformers import SentenceTransformer
3
- import json
4
- import os
5
- from typing import List, Dict
6
 
7
class DocumentRetriever:
    """Semantic document search over a fixed JSON corpus.

    Every document is embedded once at construction time with a
    sentence-transformers model; queries are then answered by
    dot-product similarity against the cached embedding matrix.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        self.model = SentenceTransformer(model_name)
        self.documents = self._load_documents()
        self.doc_embeddings = self._embed_documents()

    def _load_documents(self) -> List[Dict]:
        """Read the corpus from its fixed on-disk JSON location."""
        with open('data/rupeia_document.json', 'r') as corpus_file:
            return json.load(corpus_file)

    def _embed_documents(self) -> np.ndarray:
        """Encode the 'content' field of every document into one matrix."""
        contents = [entry['content'] for entry in self.documents]
        return self.model.encode(contents)

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return the top_k documents most similar to *query*, best first."""
        query_vec = self.model.encode(query)
        similarities = np.dot(self.doc_embeddings, query_vec)
        # argsort is ascending, so take the tail and reverse it for best-first.
        ranked = np.argsort(similarities)[-top_k:][::-1]
        return [self.documents[idx] for idx in ranked]
26
-
27
def retrieve_relevant_documents(query: str) -> List[Dict]:
    """Module-level convenience wrapper around DocumentRetriever.retrieve.

    Returns the documents most relevant to *query* (default top-3).

    Fix: reuse a single cached DocumentRetriever instance — the original
    built a fresh one per call, re-loading the sentence-transformer model
    and re-embedding the entire corpus on every request.
    """
    retriever = getattr(retrieve_relevant_documents, "_retriever", None)
    if retriever is None:
        retriever = DocumentRetriever()
        retrieve_relevant_documents._retriever = retriever
    return retriever.retrieve(query)
 
1
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
+ import torch
 
 
 
3
 
4
class ResponseGenerator:
    """Causal-LM text generator built on a Hugging Face model.

    Loads the tokenizer and model once at construction; ``generate``
    produces a greedy continuation of a prompt.
    """

    def __init__(self, model_name='distilgpt2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # GPT-2-family tokenizers ship without a pad token; reuse EOS.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Fix: inference only — disable dropout for deterministic outputs.
        self.model.eval()

    def generate(self, prompt: str, max_length: int = 150) -> str:
        """Generate text continuing *prompt*.

        The decoded string includes the prompt itself; *max_length* caps
        the TOTAL token count (prompt + continuation).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        # Fix: no autograd graph during inference (saves memory/time).
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                # Fix: pass the attention mask explicitly — with
                # pad_token == eos_token the model cannot infer it,
                # which can silently degrade generation quality.
                attention_mask=inputs.attention_mask,
                max_length=max_length,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
19
 
20
def generate_response(query: str, context: list) -> str:
    """Answer *query* using retrieved *context* documents.

    Builds a Question/Context/Answer prompt from each document's
    'content' field and runs it through the language model; returns the
    decoded model output (which includes the prompt).

    Fix: the ResponseGenerator (tokenizer + model load) is created once
    and cached on the function instead of being rebuilt on every call,
    which was the dominant cost of each request.
    """
    generator = getattr(generate_response, "_cached_generator", None)
    if generator is None:
        generator = ResponseGenerator()
        generate_response._cached_generator = generator
    context_str = "\n".join(doc['content'] for doc in context)
    prompt = f"Question: {query}\nContext: {context_str}\nAnswer:"
    return generator.generate(prompt)