"""Gradio app: enhance a short image concept into a rich text-to-image prompt.

Wraps the shb777/PromptTuner-v0.1 model (a gemma3-270M finetune) in a
streaming UI, and highlights which of the user's key phrases survived into
the enhanced output so they can verify their core idea was preserved.

NOTE(review): this file was recovered from a source whose HTML markup had
been stripped.  All HTML fragments below (placeholder paragraphs, <mark>
highlight tags, form labels) are reconstructed from the CSS class names
that survived (.output-text, .placeholder-text, .highlight-keyword,
.form-label) — confirm visually against the deployed app.
"""

import re
from threading import Thread

import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# NLTK is optional: it only improves key-phrase extraction.  Without it we
# fall back to a plain regex word scan.
try:
    import nltk
    from nltk.tokenize import word_tokenize
    from nltk.chunk import ne_chunk
    from nltk.tag import pos_tag
    NLTK_AVAILABLE = True
except ImportError:
    NLTK_AVAILABLE = False

# System prompt fed to the model before the user's concept.
# NOTE(review): original line breaks were lost in the garbled source; the
# bullet layout below is reconstructed — confirm against the model card.
SYSTEM_PROMPT = """You are an expert creative director specializing in visual descriptions for image generation.

Your task: Transform the user's concept into a rich, detailed image description while PRESERVING their core idea.

IMPORTANT RULES:
1. Keep ALL key elements (intents, entities) from the original concept
2. Enhance with artistic details, NOT change the fundamental idea
3. Maintain the user's intended subject, action, and setting

You should elaborate on:
• Visual composition and perspective (bird's eye, close-up, wide angle, etc.)
• Artistic style (photorealistic, impressionist, specific artist like Van Gogh, etc.)
• Color palette and color temperature
• Lighting (golden hour, dramatic shadows, soft diffused, etc.)
• Atmosphere and mood
• Textures and materials (rough, smooth, metallic, organic, etc.)
• Technical details (medium, brushwork, rendering style)
• Environmental context (time of day, weather, season, era)
• Level of detail and focus points

Output format: A single, flowing paragraph that reads naturally as an image prompt."""

# Hard-coded off in this build; flip to enable the ZeroGPU path.
CUDA_AVAILABLE = False

# Keyed by "is CUDA": models[False]/tokenizers[False] is the always-loaded
# CPU copy, models[True] the optional CUDA copy sharing the same tokenizer.
models = {}
tokenizers = {}

models[False] = AutoModelForCausalLM.from_pretrained("shb777/PromptTuner-v0.1")
tokenizers[False] = AutoTokenizer.from_pretrained("shb777/PromptTuner-v0.1")
models[False].eval()

if CUDA_AVAILABLE:
    models[True] = AutoModelForCausalLM.from_pretrained("shb777/PromptTuner-v0.1").to('cuda')
    tokenizers[True] = tokenizers[False]
    models[True].eval()

# Download NLTK data (only the pieces extract_key_phrases needs).
if NLTK_AVAILABLE:
    try:
        nltk.data.find('tokenizers/punkt')
        nltk.data.find('taggers/averaged_perceptron_tagger')
        nltk.data.find('chunkers/maxent_ne_chunker')
        nltk.data.find('corpora/words')
    except LookupError:
        nltk.download('punkt', quiet=True)
        nltk.download('averaged_perceptron_tagger', quiet=True)
        nltk.download('maxent_ne_chunker', quiet=True)
        nltk.download('words', quiet=True)


def extract_key_phrases(text: str) -> list:
    """Extract up to 20 key phrases (lowercased) from *text*.

    Uses NLTK POS tagging + named-entity chunking when available; otherwise
    falls back to a regex scan for words of 3+ letters.  Longer phrases are
    returned first so callers can match greedily.
    """
    if not NLTK_AVAILABLE:
        words = re.findall(r'\b[a-zA-Z]{3,}\b', text.lower())
        return list(set(words))

    phrases = []
    try:
        tokens = word_tokenize(text)
        tagged = pos_tag(tokens)
        chunks = ne_chunk(tagged)

        # Named-entity subtrees become whole phrases; runs of nouns (and
        # trailing adjectives) are accumulated into multi-word phrases.
        current_phrase = []
        for chunk in chunks:
            if hasattr(chunk, 'label'):
                phrase = ' '.join([token for token, _ in chunk.leaves()])
                phrases.append(phrase.lower())
            elif chunk[1].startswith('NN'):
                current_phrase.append(chunk[0])
            elif chunk[1].startswith('JJ') and current_phrase:
                current_phrase.append(chunk[0])
            else:
                if current_phrase:
                    phrases.append(' '.join(current_phrase).lower())
                    current_phrase = []
        if current_phrase:
            phrases.append(' '.join(current_phrase).lower())

        # Standalone adjectives/adverbs are descriptive and worth keeping.
        for word, tag in tagged:
            if tag.startswith('JJ') or tag in ('RB', 'RBR', 'RBS'):
                phrases.append(word.lower())
    except Exception:
        # Any NLTK failure degrades to the regex fallback.
        words = re.findall(r'\b[a-zA-Z]{3,}\b', text.lower())
        phrases = list(set(words))

    # Also include original multi-word phrases (2-4 words of 3+ letters).
    multi_word = re.findall(r'\b[a-zA-Z]{3,}(?:\s+[a-zA-Z]{3,}){1,3}\b', text)
    phrases.extend([mw.lower() for mw in multi_word])

    # Sort by length (longer first) and remove duplicates.
    phrases = list(set(phrases))
    phrases.sort(key=len, reverse=True)
    return phrases[:20]


def _wrap_output(html_body: str) -> str:
    """Wrap already-escaped/trusted HTML in the output paragraph container."""
    # NOTE(review): reconstructed markup — tag/class inferred from the CSS.
    return f'<p class="output-text">{html_body}</p>'


def highlight_matches(original_input: str, enhanced_output: str) -> str:
    """Return *enhanced_output* as HTML with the user's key phrases marked.

    FIX: the original implementation recorded match offsets across several
    successive ``re.sub`` passes while mutating the string, so the recorded
    spans no longer lined up with the text being scanned.  Here all spans
    are collected against the unmodified text first, then the highlighted
    string is assembled in a single pass.
    """
    if not original_input.strip():
        return _wrap_output(enhanced_output)

    key_phrases = extract_key_phrases(original_input)
    if not key_phrases:
        return _wrap_output(enhanced_output)

    # Longer phrases first so they win over their sub-words.
    key_phrases.sort(key=len, reverse=True)

    # Collect non-overlapping match spans on the unmodified output text.
    spans = []
    for phrase in key_phrases:
        pattern = re.compile(r'\b' + re.escape(phrase) + r'\b', re.IGNORECASE)
        for match in pattern.finditer(enhanced_output):
            start, end = match.span()
            if not any(start < e and s < end for s, e in spans):
                spans.append((start, end))
    spans.sort()

    # Stitch the highlighted string together in one pass.
    pieces = []
    cursor = 0
    for start, end in spans:
        pieces.append(enhanced_output[cursor:start])
        pieces.append(
            f'<mark class="highlight-keyword">{enhanced_output[start:end]}</mark>'
        )
        cursor = end
    pieces.append(enhanced_output[cursor:])
    return _wrap_output(''.join(pieces))


# Reconstructed placeholder markup (see module NOTE).
_PLACEHOLDER_HTML = (
    '<p class="output-text placeholder-text">'
    'Your enhanced prompt will appear here</p>'
)


@spaces.GPU(duration=30)
def generate_gpu(inputs, generation_kwargs):
    """Run generation on the ZeroGPU-allocated CUDA model."""
    return models[True].generate(**inputs, **generation_kwargs)


def enhance_prompt(user_prompt: str, use_gpu=CUDA_AVAILABLE):
    """Enhance the user's prompt using the AI model.

    Generator yielding 4-tuples for Gradio streaming:
    (highlighted HTML, raw enhanced text, enhance-btn update, clear-btn update).
    """
    # Validate input
    if not user_prompt or not user_prompt.strip():
        yield (
            '<p class="output-text placeholder-text">'
            'Please enter a prompt to enhance.</p>',
            "",
            gr.update(interactive=True),
            gr.update(interactive=True),
        )
        return

    # Prepare messages
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    use_gpu = use_gpu and CUDA_AVAILABLE
    tokenizer = tokenizers[False]

    # Tokenize input
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    if use_gpu:
        inputs = {k: v.to('cuda') for k, v in inputs.items()}

    # Set up streaming
    streamer = TextIteratorStreamer(
        tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = {
        'max_new_tokens': 512,
        'streamer': streamer,
        'do_sample': True,
        'temperature': 1,
        'top_p': 0.95,
        'top_k': 64,
    }

    # Show loading state, disabling both buttons while generating.
    yield _PLACEHOLDER_HTML, "", gr.update(interactive=False), gr.update(interactive=False)

    # FIX: initialize before the try so the final yield cannot raise
    # UnboundLocalError if generation fails before the first token.
    output = ""
    try:
        # Start generation in a separate thread so we can stream from it.
        if use_gpu:
            thread = Thread(
                target=generate_gpu,
                kwargs={'inputs': inputs, 'generation_kwargs': generation_kwargs},
            )
        else:
            thread = Thread(
                target=models[False].generate,
                kwargs={**inputs, **generation_kwargs},
            )
        thread.start()

        # Stream output
        for text in streamer:
            output += text
            highlighted = highlight_matches(user_prompt, output)
            yield highlighted, output, gr.update(), gr.update()
    except gr.exceptions.Error as e:
        if use_gpu:
            # ZeroGPU allocation failed — fall back to the CPU model.
            gr.Warning(str(e))
            gr.Info('Retrying with CPU')
            inputs = {k: v.cpu() for k, v in inputs.items()}
            thread = Thread(
                target=models[False].generate,
                kwargs={**inputs, **generation_kwargs},
            )
            thread.start()
            output = ""
            for text in streamer:
                output += text
                highlighted = highlight_matches(user_prompt, output)
                yield highlighted, output, gr.update(), gr.update()
        else:
            # FIX: pass the message, not the exception object.
            raise gr.Error(str(e))

    # Final output with interactive buttons restored
    final_highlighted = highlight_matches(user_prompt, output)
    yield final_highlighted, output, gr.update(interactive=True), gr.update(interactive=True)


# =============================================================================
# CSS - shadcn/ui inspired Zinc Dark Theme
# =============================================================================
custom_css = """
/* ========== CSS VARIABLES ========== */
:root { --background: 240 10% 3.9%; --foreground: 0 0% 98%; --card: 240 10% 4.5%; --card-border: 240 3.7% 18%; --primary: 0 0% 98%; --primary-foreground: 240 5.9% 10%; --secondary: 240 3.7% 15.9%; --secondary-foreground: 0 0% 98%; --muted: 240 3.7% 15.9%; --muted-foreground: 240 5% 64.9%; --accent: 240 3.7% 15.9%; --accent-foreground: 0 0% 98%; --border: 240 3.7% 18%; --input: 240 3.7% 18%; --ring: 240 5.9% 85%; --radius: 0.625rem; }

/* ========== GLOBAL STYLES ========== */
.gradio-container { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; background: hsl(var(--background)) !important; color: hsl(var(--foreground)); }
.gradio-container mark { background: hsl(var(--accent) / 0.6); color: hsl(var(--accent-foreground)); padding: 0.15em 0.35em; border-radius: calc(var(--radius) - 2px); font-weight: 500; border: 1px solid hsl(var(--border) / 0.5); }
footer { display: none !important; }

/* ========== MARKDOWN ========== */
.gradio-markdown { color: hsl(var(--foreground)) !important; font-size: 0.9375rem !important; line-height: 1.6 !important; }
.gradio-markdown:first-child { margin-bottom: 2rem; padding-bottom: 1.5rem; border-bottom: 1px solid hsl(var(--border)); }
.gradio-markdown:last-child { padding-top: 1.5rem; border-top: 1px solid hsl(var(--border)); color: hsl(var(--muted-foreground)) !important; }
.gradio-markdown a { color: hsl(var(--foreground)) !important; text-decoration: none; border-bottom: 1px solid hsl(var(--border)); transition: border-color 0.2s ease; }
.gradio-markdown a:hover { border-color: hsl(var(--ring)); }

/* ========== LAYOUT ========== */
.main-grid { display: grid; grid-template-columns: 1fr 1fr; gap: 2rem; }
@media (max-width: 768px) { .main-grid { grid-template-columns: 1fr; } }

/* ========== CARDS ========== */
.card { background: hsl(var(--card)); border: 1px solid hsl(var(--card-border)); border-radius: var(--radius); padding: 1.5rem; box-shadow: 0 1px 2px rgba(0, 0, 0, 0.3), 0 0 0 1px rgba(255, 255, 255, 0.02) inset; }

/* ========== FORM ELEMENTS ========== */
.form-label { font-size: 0.875rem; font-weight: 500; margin-bottom: 0.5rem; display: block; color: hsl(var(--foreground)); }
.input-textarea { width: 100%; min-height: 140px; padding: 0.875rem; font-size: 0.9375rem; line-height: 1.6; background: hsl(var(--background)); border: 1px solid hsl(var(--input)); border-radius: var(--radius); color: hsl(var(--foreground)); transition: all 0.2s ease; resize: vertical; box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2); }
.input-textarea::placeholder { color: hsl(var(--muted-foreground) / 0.7); }
.input-textarea:focus { outline: none; border-color: hsl(var(--ring)); box-shadow: 0 0 0 3px hsl(var(--ring) / 0.1), 0 1px 2px rgba(0, 0, 0, 0.2); background: hsl(var(--background) / 0.8); }

/* ========== BUTTONS ========== */
.btn { display: inline-flex; align-items: center; justify-content: center; gap: 0.5rem; font-size: 0.9375rem; font-weight: 500; padding: 0.625rem 1.25rem; border-radius: var(--radius); cursor: pointer; transition: all 0.2s ease; border: none; }
.btn:focus-visible { outline: none; box-shadow: 0 0 0 2px hsl(var(--background)), 0 0 0 4px hsl(var(--ring)); }
.btn-primary { background: hsl(var(--primary)); color: hsl(var(--primary-foreground)); box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2), 0 0 0 1px rgba(255, 255, 255, 0.05) inset; }
.btn-primary:hover { opacity: 0.95; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.25), 0 0 0 1px rgba(255, 255, 255, 0.08) inset; }
.btn-primary:active { transform: translateY(1px); }
.btn-primary:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-secondary { background: hsl(var(--secondary)); color: hsl(var(--secondary-foreground)); border: 1px solid hsl(var(--border)); box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2); }
.btn-secondary:hover { background: hsl(var(--secondary) / 0.8); border-color: hsl(var(--muted-foreground) / 0.5); }
.btn-secondary:active { transform: translateY(1px); }

/* ========== OUTPUT CONTAINER ========== */
.output-container { min-height: 140px; padding: 0.875rem; border: 1px solid hsl(var(--input)); border-radius: var(--radius); background: hsl(var(--background)); box-shadow: 0 1px 2px rgba(0, 0, 0, 0.15), 0 0 0 1px rgba(255, 255, 255, 0.02) inset; }
.output-text { color: hsl(var(--foreground)); font-size: 0.9375rem; line-height: 1.75; margin: 0; }
.placeholder-text { color: hsl(var(--muted-foreground)); }
.highlight-keyword { background: hsl(var(--accent) / 0.6); color: hsl(var(--accent-foreground)); padding: 0.15em 0.35em; border-radius: calc(var(--radius) - 2px); font-weight: 500; border: 1px solid hsl(var(--border) / 0.5); }

/* ========== EXAMPLES ========== */
.examples-section { padding: 1.5rem; background: hsl(var(--card)); border: 1px solid hsl(var(--card-border)); border-radius: var(--radius); box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2), 0 0 0 1px rgba(255, 255, 255, 0.02) inset; }

/* ========== SPACING UTILITIES ========== */
.mt-6 { margin-top: 1.5rem; }
.flex { display: flex; }
.gap-2 { gap: 0.5rem; }
"""

# =============================================================================
# Gradio Interface
# =============================================================================
with gr.Blocks(css=custom_css, title="Prompt Enhancer") as demo:
    # Header
    with gr.Row():
        gr.Markdown(
            "Transform your creative ideas into detailed, vivid prompts "
            "for AI image generation."
        )

    # Main content - two column layout
    with gr.Row(elem_classes=["main-grid"]):
        # Input column
        with gr.Column(elem_classes=["card"]):
            # NOTE(review): label markup lost in the garbled source; reconstructed.
            gr.HTML('<span class="form-label">Your Concept</span>')
            input_text = gr.Textbox(
                placeholder="Describe your image concept... e.g., fox, red tail, blue moon, clouds",
                lines=5,
                show_label=False,
                autofocus=True,
                container=False,
                elem_classes=["input-textarea"],
            )
            with gr.Row(elem_classes=["flex gap-2 mt-6"]):
                enhance_btn = gr.Button(
                    "Enhance Prompt",
                    variant="primary",
                    scale=2,
                    elem_classes=["btn", "btn-primary"],
                )
                clear_btn = gr.Button(
                    "Clear",
                    scale=1,
                    elem_classes=["btn", "btn-secondary"],
                )

        # Output column
        with gr.Column(elem_classes=["card"]):
            gr.HTML('<span class="form-label">Enhanced Prompt</span>')
            output_html = gr.HTML(
                value=_PLACEHOLDER_HTML,
                elem_classes=["output-container"],
            )
            # Hidden raw text mirror of the enhanced prompt (for copying/MCP).
            raw_output = gr.Textbox(visible=False)

    # Examples section
    with gr.Column(elem_classes=["examples-section"]):
        gr.Examples(
            examples=[
                ["fox, red tail, blue moon, clouds"],
                ["room with french window, cozy morning vibes, minimal"],
                ["anime style, sunset, japan"],
            ],
            inputs=input_text,
            label="Examples",
        )

    # Footer
    with gr.Row():
        gr.Markdown(
            "Powered by [PromptTuner](https://huggingface.co/shb777/PromptTuner-v0.1), "
            "a finetuned gemma3-270M model specifically designed to enhance text prompts "
            "for text-to-image generation."
        )

    # =========================================================================
    # Event Handlers
    # =========================================================================
    # NOTE(review): gr.State(False) pins use_gpu to False — presumably
    # deliberate while CUDA_AVAILABLE is hard-coded off; confirm.
    enhance_btn.click(
        fn=enhance_prompt,
        inputs=[input_text, gr.State(False)],
        outputs=[output_html, raw_output, enhance_btn, clear_btn],
    )

    clear_btn.click(
        fn=lambda: (
            "",
            _PLACEHOLDER_HTML,
            "",
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        inputs=None,
        outputs=[input_text, output_html, raw_output, enhance_btn, clear_btn],
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch(mcp_server=True)