| | from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration |
| | from datasets import load_dataset |
| | import torch |
def load_rag_model(dataset_name="your_dataset"):
    """Load the RAG tokenizer, retriever, and generation model.

    Parameters
    ----------
    dataset_name : str
        Hugging Face datasets identifier for the passage corpus backing the
        retriever. Defaults to the original placeholder ``"your_dataset"`` —
        replace with a real dataset id before use.

    Returns
    -------
    tuple
        ``(tokenizer, model)`` where ``model`` is a ``RagTokenForGeneration``
        wired to a retriever over ``dataset_name``.
    """
    tokenizer = RagTokenizer.from_pretrained("nklomp/rag-example")
    # Bug fix: the `dataset=` kwarg of RagConfig expects a dataset *name*
    # string (e.g. "wiki_dpr"), not a loaded Dataset object. To back the
    # retriever with an in-memory dataset you must pass index_name="custom"
    # together with indexed_dataset=<Dataset>.
    # NOTE(review): the custom dataset must expose 'title', 'text' and
    # 'embeddings' columns and carry a FAISS index on 'embeddings' — confirm
    # against the transformers RagRetriever documentation; load_dataset may
    # also return a DatasetDict, in which case a split must be selected.
    retriever = RagRetriever.from_pretrained(
        "nklomp/rag-example",
        index_name="custom",
        indexed_dataset=load_dataset(dataset_name),
    )
    model = RagTokenForGeneration.from_pretrained(
        "nklomp/rag-example", retriever=retriever
    )
    return tokenizer, model
| |
|
def query_model(tokenizer, model, query):
    """Generate a text answer for *query* with the given tokenizer/model pair.

    Parameters
    ----------
    tokenizer : tokenizer with ``__call__`` and ``decode``
    model : model exposing ``generate(**encoded)``
    query : str
        The user question to answer.

    Returns
    -------
    str
        The decoded generation for the first returned sequence.
    """
    encoded = tokenizer(query, return_tensors="pt")
    # Inference only — disable gradient tracking to save memory.
    with torch.no_grad():
        generated = model.generate(**encoded)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
| |
|
| | |
def main():
    """Load the RAG pipeline and print the answer to a sample query."""
    tokenizer, model = load_rag_model()
    user_query = "I am looking for companies that can handle a large construction project."
    response = query_model(tokenizer, model, user_query)
    print(response)


# Guard the entry point: the original ran model loading and inference at
# import time, which would trigger a full model download whenever this
# module was imported. Behavior as a script is unchanged.
if __name__ == "__main__":
    main()
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|