ayush2917 commited on
Commit
e0774c5
·
verified ·
1 Parent(s): 34fbe97

Update src/generation.py

Browse files
Files changed (1) hide show
  1. src/generation.py +7 -21
src/generation.py CHANGED
@@ -1,24 +1,10 @@
1
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
- import torch
3
 
4
  class ResponseGenerator:
5
  def __init__(self, model_name='distilgpt2'):
6
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- self.model = AutoModelForCausalLM.from_pretrained(model_name)
8
- self.tokenizer.pad_token = self.tokenizer.eos_token
9
-
10
- def generate(self, prompt: str, max_length: int = 150) -> str:
11
- inputs = self.tokenizer(prompt, return_tensors="pt")
12
- outputs = self.model.generate(
13
- inputs.input_ids,
14
- max_length=max_length,
15
- num_return_sequences=1,
16
- pad_token_id=self.tokenizer.eos_token_id
17
- )
18
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
19
-
20
- def generate_response(query: str, context: list) -> str:
21
- generator = ResponseGenerator()
22
- context_str = "\n".join([doc['content'] for doc in context])
23
- prompt = f"Question: {query}\nContext: {context_str}\nAnswer:"
24
- return generator.generate(prompt)
 
1
+ from transformers import pipeline
 
2
 
3
class ResponseGenerator:
    """Generate grounded answers with a Hugging Face text-generation pipeline.

    The model is loaded once in ``__init__`` and reused across calls; each
    ``generate`` call builds a "Context / Question / Answer" prompt from the
    retrieved documents and returns the model's continuation.
    """

    def __init__(self, model_name='distilgpt2'):
        # Build the pipeline once (bundles tokenizer + model); re-creating it
        # per request would reload the model weights every call.
        self.generator = pipeline('text-generation', model=model_name)
        # GPT-2-family models ship without a pad token; reuse EOS so the
        # pipeline does not warn (or fail on batched/padded inputs) at
        # generation time. This mirrors the tokenizer setup that the
        # previous (pre-pipeline) implementation had and the refactor lost.
        if self.generator.tokenizer.pad_token is None:
            self.generator.tokenizer.pad_token = self.generator.tokenizer.eos_token

    def generate(self, prompt: str, context: list, max_length: int = 150) -> str:
        """Answer *prompt* using the retrieved *context* documents.

        Args:
            prompt: The user's question.
            context: Retrieved documents; each item is a dict with a
                ``'content'`` key (assumed from usage here — confirm against
                the retriever's output schema).
            max_length: Maximum number of NEW tokens to generate for the
                answer (the prompt is not counted against this budget).

        Returns:
            The full generated text, which includes the prompt prefix
            followed by the model's answer (pipeline default behavior,
            preserved for existing callers).
        """
        context_str = "\n".join(doc['content'] for doc in context)
        full_prompt = f"Context: {context_str}\nQuestion: {prompt}\nAnswer:"
        # Bug fix: `max_length` counts prompt + continuation tokens, so a long
        # retrieved context used up the whole budget and left no room for the
        # answer (with a transformers warning). `max_new_tokens` budgets only
        # the generated continuation. `pad_token_id` is passed explicitly to
        # silence the per-call "Setting pad_token_id to eos_token_id" warning.
        outputs = self.generator(
            full_prompt,
            max_new_tokens=max_length,
            num_return_sequences=1,
            pad_token_id=self.generator.tokenizer.eos_token_id,
        )
        return outputs[0]['generated_text']