from transformers import pipeline
def inference(text, model, tokenizer, args=None):
    """Generate a text completion for *text* with a transformers pipeline.

    Builds a ``text-generation`` pipeline from the given *model* and
    *tokenizer* (placed automatically across available devices via
    ``device_map="auto"``), applies default sampling parameters, merges in
    any caller overrides from *args*, and returns the generated string.

    Args:
        text: The prompt string to complete.
        model: Model (or model identifier) accepted by
            ``transformers.pipeline``.
        tokenizer: Tokenizer (or identifier) accepted by
            ``transformers.pipeline``.
        args: Optional dict of generation parameters that override the
            defaults below (e.g. ``{"temperature": 0.2}``). Defaults to
            an empty dict.

    Returns:
        The ``generated_text`` field of the first pipeline result.
    """
    # Fix: the original signature used a mutable default (`args={}`).
    # A shared dict default persists across calls, so any mutation would
    # silently leak state between invocations; use the None sentinel instead.
    if args is None:
        args = {}

    # NOTE(review): the pipeline is reconstructed on every call, which is
    # expensive for repeated inference. Caching it would change module-level
    # side effects, so it is only flagged here rather than changed.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto",
    )

    # Default generation parameters; any key in *args* overrides these.
    params = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 50,
        "do_sample": True,
        "repetition_penalty": 1.1,
    }
    params.update(args)

    # The pipeline returns a list of result dicts; take the first completion.
    result = generator(text, **params)
    return result[0]["generated_text"]