# generator.py — answer generation for the ESG RAG pipeline
# (source: GirishaBuilds01, commit e7ce308, verified)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from config import GENERATION_MODEL, MAX_NEW_TOKENS
class Generator:
    """Generates concise, context-grounded answers to ESG questions with a causal LM."""

    def __init__(self):
        """Load the tokenizer and model named by GENERATION_MODEL (CPU-friendly fp32).

        NOTE(review): `from_pretrained` downloads/loads weights — this is a slow,
        I/O-heavy constructor; instantiate once and reuse.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL)
        # Many causal LMs (e.g. GPT-style) ship without a pad token; fall back
        # to EOS so tokenization/generation does not warn or fail on padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            GENERATION_MODEL,
            torch_dtype=torch.float32,   # full precision — safest default on CPU
            low_cpu_mem_usage=True,
        )
        # Inference only: disable dropout etc. for deterministic greedy decoding.
        self.model.eval()

    def generate(self, query, context_chunks):
        """Answer `query` using only the retrieved `context_chunks`.

        Args:
            query: The user's question (str).
            context_chunks: Iterable of context strings; joined with blank lines
                into the prompt. An empty iterable yields an empty context section.

        Returns:
            The model's newly generated answer text, with the prompt stripped
            and special tokens removed.
        """
        context = "\n\n".join(context_chunks)
        prompt = f"""
You are an ESG expert assistant.
Context:
{context}
Question:
{query}
Provide a concise, factual answer grounded in the context.
"""
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,   # greedy decoding: deterministic, factual style
            )
        # BUG FIX: decoding outputs[0] verbatim echoes the entire prompt
        # (context + question) back to the caller, because causal-LM `generate`
        # returns prompt tokens followed by new tokens. Slice off the prompt
        # so only the generated answer is returned.
        prompt_len = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)