| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import torch |
| from config import GENERATION_MODEL, MAX_NEW_TOKENS |
|
|
|
|
class Generator:
    """Answer questions with a causal LM, grounded in retrieved context chunks."""

    def __init__(self):
        """Load the tokenizer and model named by ``config.GENERATION_MODEL``.

        Weights are loaded in float32 (CPU-friendly); ``low_cpu_mem_usage``
        streams tensors during load to reduce peak RAM. The model is switched
        to eval mode since this class only performs inference.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL)
        self.model = AutoModelForCausalLM.from_pretrained(
            GENERATION_MODEL,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        # Inference only: disable dropout / training-mode behavior so
        # greedy decoding below is deterministic.
        self.model.eval()

    def generate(self, query, context_chunks):
        """Return a generated answer to *query* grounded in *context_chunks*.

        Args:
            query: The user question (str).
            context_chunks: Iterable of retrieved passage strings; they are
                joined with blank lines into the prompt.

        Returns:
            Only the newly generated text (str) — the prompt tokens are
            stripped before decoding.
        """
        context = "\n\n".join(context_chunks)

        prompt = f"""
You are an ESG expert assistant.

Context:
{context}

Question:
{query}

Provide a concise, factual answer grounded in the context.
"""

        inputs = self.tokenizer(prompt, return_tensors="pt")

        # Greedy decoding (do_sample=False) for reproducible answers;
        # no_grad avoids building an autograd graph during inference.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
            )

        # BUG FIX: for causal LMs, generate() returns the prompt tokens
        # followed by the completion. Decoding outputs[0] in full made the
        # returned "answer" start with the entire prompt. Slice off the
        # prompt so callers get only the completion.
        prompt_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
|
|