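# Generate role-based creative prompts with the gouthamsai78/STACKS model.
# Requires `torch` and `transformers` (pip install torch transformers).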
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

|
def load_stacks():
    """Load the STACKS model and tokenizer from the Hugging Face Hub."""
    model = AutoModelForCausalLM.from_pretrained(
        "gouthamsai78/STACKS",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="eager",
    )
    tokenizer = AutoTokenizer.from_pretrained("gouthamsai78/STACKS")
    return model, tokenizer
|
|
def generate_prompt(role, model, tokenizer, temperature=0.8):
    """Generate a creative prompt for the given role."""
    input_text = f"### Task: Generate a creative prompt for someone acting as {role}\n### Generated Prompt:"
    # Move inputs to the model's device; device_map="auto" may have placed
    # the model on a GPU while the tokenizer returns CPU tensors.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Slice off the prompt tokens so only newly generated text is decoded;
    # slicing the decoded string can mis-align when decode(encode(x)) != x.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
|
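# Optional (an assumption, not part of the original flow): for reproducible
# sampling, seed all RNGs before generating via transformers' `set_seed`:
#   from transformers import set_seed
#   set_seed(42)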
if __name__ == "__main__":
    model, tokenizer = load_stacks()

    roles = ["chef", "detective", "astronaut", "teacher", "artist"]
    for role in roles:
        prompt = generate_prompt(role, model, tokenizer)
        print(f"**{role.title()}**: {prompt}\n")