---
base_model: google-t5/t5-base
datasets:
- gokaygokay/prompt-enhancer-dataset
language:
- en
library_name: transformers
license: apache-2.0
pipeline_tag: text2text-generation
---
```python
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# `torch` is required for this CUDA availability check (it was previously not imported).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Model checkpoint
model_checkpoint = "Hatman/Flux-Prompt-Enhance"
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Model
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# Text-to-text pipeline; repetition_penalty is forwarded to generation.
enhancer = pipeline('text2text-generation',
                    model=model,
                    tokenizer=tokenizer,
                    repetition_penalty=1.2,
                    device=device)
max_target_length = 256
# Task prefix prepended to the short prompt before generation.
prefix = "enhance prompt: "
short_prompt = "beautiful house with text 'hello'"
answer = enhancer(prefix + short_prompt, max_length=max_target_length)
final_answer = answer[0]['generated_text']
print(final_answer)
# Example output:
# a two-story house with white trim, large windows on the second floor,
# three chimneys on the roof, green trees and shrubs in front of the house,
# stone pathway leading to the front door, text on the house reads "hello" in all caps,
# blue sky above, shadows cast by the trees, sunlight creating contrast on the house's facade,
# some plants visible near the bottom right corner, overall warm and serene atmosphere.
```
<h1>A Script for ComfyUI</h1>
```python
import torch
import random
import hashlib
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
class PromptEnhancer:
    """ComfyUI node that expands a short prompt with the Hatman/Flux-Prompt-Enhance model.

    The tokenizer and seq2seq model are loaded in ``__init__`` and the model is
    moved to CUDA when available, otherwise CPU. ``enhance_prompt`` is the node's
    entry point (see ``FUNCTION``).
    """

    # Task prefix prepended to every prompt before tokenization.
    _PREFIX = "enhance prompt: "

    def __init__(self):
        # Set up device: prefer the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Model checkpoint
        self.model_checkpoint = "Hatman/Flux-Prompt-Enhance"
        # Tokenizer and Model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_checkpoint)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_checkpoint).to(self.device)
        # Display state, updated by enhance_prompt() and exposed through
        # the NODE_TITLE / GENERATED_PROMPT properties.
        self.node_title = "Prompt Enhancer"
        self.generated_prompt = ""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required and optional inputs for ComfyUI."""
        return {
            "required": {
                "prompt": ("STRING",),
                "seed": ("INT", {"default": 42, "min": 0, "max": 4294967295}),  # Default seed, larger range
                # NOTE(review): transformers rejects repetition_penalty <= 0 at generation
                # time; consider raising this minimum above 0.
                "repetition_penalty": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0}),  # Default repetition penalty
                "max_target_length": ("INT", {"default": 256, "min": 1, "max": 1024}),  # Default max target length
                # A temperature of 0 selects greedy (non-sampling) decoding.
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),  # Default temperature
                "top_k": ("INT", {"default": 50, "min": 1, "max": 1000}),  # Default top-k
                "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0}),  # Default top-p
            },
            "optional": {
                "prompts_list": ("LIST",),  # List of prompts
            },
        }

    RETURN_TYPES = ("STRING",)  # Return only one string: the enhanced prompt
    FUNCTION = "enhance_prompt"
    CATEGORY = "TextEnhancement"

    def generate_large_seed(self, seed, prompt):
        """Derive a deterministic 32-bit seed from ``seed`` and ``prompt`` via SHA-256."""
        unique_string = f"{seed}_{prompt}"
        hash_object = hashlib.sha256(unique_string.encode())
        return int(hash_object.hexdigest(), 16) % (2**32)

    def _generate_one(self, text, max_target_length, repetition_penalty, temperature, top_k, top_p):
        """Run one generation for ``text`` (without prefix) and return the decoded string."""
        input_ids = self.tokenizer(self._PREFIX + text, return_tensors="pt").input_ids.to(self.device)
        # Draw a per-prompt seed from the (already seeded) torch RNG, then reseed
        # both RNGs with it; this keeps results reproducible for a given seed/prompt.
        random_seed = torch.randint(0, 2**32 - 1, (1,)).item()
        torch.manual_seed(random_seed)
        random.seed(random_seed)

        gen_kwargs = {
            "max_length": max_target_length,
            "num_return_sequences": 1,
            "repetition_penalty": repetition_penalty,
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p)
        else:
            # transformers raises ValueError for a sampling temperature <= 0, so a
            # temperature of 0 (allowed by INPUT_TYPES) is mapped to greedy decoding.
            gen_kwargs["do_sample"] = False

        outputs = self.model.generate(input_ids, **gen_kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def enhance_prompt(self, prompt, seed=42, repetition_penalty=1.2, max_target_length=256,
                       temperature=0.7, top_k=50, top_p=0.9, prompts_list=None):
        """Enhance ``prompt`` (or every entry of ``prompts_list``) with the model.

        Args:
            prompt: Short prompt to expand; also mixed into the seed.
            seed: User seed combined with ``prompt`` by ``generate_large_seed``.
            repetition_penalty, max_target_length, temperature, top_k, top_p:
                Generation settings forwarded to ``model.generate``; a
                ``temperature`` of 0 selects greedy decoding.
            prompts_list: Optional prompts to process instead of ``prompt``.
                ``None`` or an empty list means single-prompt mode.

        Returns:
            Single-prompt mode: a 1-tuple ``("Enhanced Prompt: <text>",)``.
            Multi-prompt mode: a list of ``("Enhanced Prompt: <text>", 1.0)`` tuples.
        """
        # Seed both RNGs from the user seed and the prompt for reproducibility.
        large_seed = self.generate_large_seed(seed, prompt)
        torch.manual_seed(large_seed)
        random.seed(large_seed)

        # An empty list previously produced an empty (invalid) output; treat it like None.
        multiple = bool(prompts_list)
        prompts = prompts_list if multiple else [prompt]

        # Placeholder: generate() is not asked for scores, so no real confidence exists.
        confidence_score = 1.0
        enhanced_prompts = []
        for p in prompts:
            final_answer = self._generate_one(
                p, max_target_length, repetition_penalty, temperature, top_k, top_p
            )
            print(f"Generated Prompt: {final_answer} (Confidence: {confidence_score:.2f})")
            enhanced_prompts.append((f"Enhanced Prompt: {final_answer}", confidence_score))

        if not multiple:
            enhanced_text = enhanced_prompts[0][0]
            self.node_title = f"Prompt Enhancer (Confidence: {confidence_score:.2f})"
            self.generated_prompt = enhanced_text
            return (enhanced_text,)

        self.node_title = "Prompt Enhancer (Multiple Prompts)"
        self.generated_prompt = "Multiple Prompts"
        # NOTE(review): this list of (str, float) tuples does not match
        # RETURN_TYPES = ("STRING",); confirm how the consuming graph handles it.
        return enhanced_prompts

    @property
    def NODE_TITLE(self):
        """Title reflecting the most recent enhance_prompt() call."""
        return self.node_title

    @property
    def GENERATED_PROMPT(self):
        """Text produced by the most recent enhance_prompt() call."""
        return self.generated_prompt
# Node classes exported to ComfyUI, keyed by their registered node name.
NODE_CLASS_MAPPINGS = dict(PromptEnhancer=PromptEnhancer)

# Human-readable display titles for the exported nodes, keyed the same way.
NODE_DISPLAY_NAME_MAPPINGS = dict(PromptEnhancer="Prompt Enhancer")
```