Spaces:
Running
on
Zero
Running
on
Zero
| import re | |
| import spaces | |
| import gradio as gr | |
| from threading import Thread | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
| try: | |
| import nltk | |
| from nltk.tokenize import word_tokenize | |
| from nltk.chunk import ne_chunk | |
| from nltk.tag import pos_tag | |
| NLTK_AVAILABLE = True | |
| except ImportError: | |
| NLTK_AVAILABLE = False | |
# Instructions sent as the "system" message with every enhancement request.
SYSTEM_PROMPT = """You are an expert creative director specializing in visual descriptions for image generation.
Your task: Transform the user's concept into a rich, detailed image description while PRESERVING their core idea.
IMPORTANT RULES:
1. Keep ALL key elements (intents, entities) from the original concept
2. Enhance with artistic details, NOT change the fundamental idea
3. Maintain the user's intended subject, action, and setting
You should elaborate on:
• Visual composition and perspective (bird's eye, close-up, wide angle, etc.)
• Artistic style (photorealistic, impressionist, specific artist like Van Gogh, etc.)
• Color palette and color temperature
• Lighting (golden hour, dramatic shadows, soft diffused, etc.)
• Atmosphere and mood
• Textures and materials (rough, smooth, metallic, organic, etc.)
• Technical details (medium, brushwork, rendering style)
• Environmental context (time of day, weather, season, era)
• Level of detail and focus points
Output format: A single, flowing paragraph that reads naturally as an image prompt."""

# GPU inference is hard-disabled here, so only the CPU entries (key ``False``)
# of the two registries below are populated at import time.
CUDA_AVAILABLE = False

# Checkpoint used for both the model and its tokenizer.
_MODEL_ID = "shb777/PromptTuner-v0.1"

# Registries keyed by "runs on GPU?": False -> CPU copy, True -> CUDA copy.
models = {False: AutoModelForCausalLM.from_pretrained(_MODEL_ID)}
tokenizers = {False: AutoTokenizer.from_pretrained(_MODEL_ID)}
models[False].eval()

if CUDA_AVAILABLE:
    # A second copy of the weights lives on the GPU; the tokenizer is shared.
    models[True] = AutoModelForCausalLM.from_pretrained(_MODEL_ID).to('cuda')
    tokenizers[True] = tokenizers[False]
    models[True].eval()
# Make sure the NLTK data used by extract_key_phrases is installed.
if NLTK_AVAILABLE:
    # (lookup path, download id) pairs. Each resource is checked on its own so
    # that one missing package does not trigger a re-download of the others.
    _NLTK_RESOURCES = (
        ('tokenizers/punkt', 'punkt'),
        ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
        ('chunkers/maxent_ne_chunker', 'maxent_ne_chunker'),
        ('corpora/words', 'words'),
        # NOTE(review): recent NLTK releases look these renamed resources up
        # instead of the pickled ones above; without them word_tokenize /
        # pos_tag / ne_chunk raise LookupError and key-phrase extraction
        # silently drops to its regex fallback. Confirm against the NLTK
        # version pinned for this app.
        ('tokenizers/punkt_tab', 'punkt_tab'),
        ('taggers/averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger_eng'),
        ('chunkers/maxent_ne_chunker_tab', 'maxent_ne_chunker_tab'),
    )
    for _resource_path, _package_id in _NLTK_RESOURCES:
        try:
            nltk.data.find(_resource_path)
        except LookupError:
            nltk.download(_package_id, quiet=True)
def extract_key_phrases(text: str) -> list:
    """Return lower-cased key phrases found in ``text``, longest first.

    When NLTK is not importable, the result is every distinct alphabetic word
    of three or more letters (no truncation).

    Otherwise the text is tokenized, POS-tagged and NE-chunked, and the result
    combines:

    * named-entity chunks,
    * runs of nouns (an adjective is appended only to a run already started),
    * every adjective and adverb on its own,
    * sequences of two to four alphabetic words (3+ letters each) taken
      straight from the text.

    If any NLTK step raises, the tagged results are discarded in favour of the
    plain word list, and the multi-word sequences are still added.

    Args:
        text: The user's original prompt.

    Returns:
        A de-duplicated list of phrases ordered by descending length, with
        ties broken alphabetically so the result is stable across runs. On the
        NLTK path the list is capped at 20 entries.
    """
    def _plain_words() -> set:
        # Regex-only fallback: distinct alphabetic words of 3+ letters.
        return set(re.findall(r'\b[a-zA-Z]{3,}\b', text.lower()))

    def _longest_first(items) -> list:
        # Sorting on (-len, text) instead of relying on set iteration order
        # keeps both the ordering and the top-20 cut deterministic.
        return sorted(set(items), key=lambda item: (-len(item), item))

    if not NLTK_AVAILABLE:
        return _longest_first(_plain_words())

    phrases = []
    try:
        tagged = pos_tag(word_tokenize(text))
        noun_run = []

        def _flush_noun_run():
            if noun_run:
                phrases.append(' '.join(noun_run).lower())
                noun_run.clear()

        for chunk in ne_chunk(tagged):
            if hasattr(chunk, 'label'):
                # Named-entity subtree. Close the pending noun run first so
                # nouns on either side of the entity are not glued together
                # into a phrase that never occurs in the text.
                _flush_noun_run()
                phrases.append(' '.join(token for token, _ in chunk.leaves()).lower())
            elif chunk[1].startswith('NN') or (chunk[1].startswith('JJ') and noun_run):
                noun_run.append(chunk[0])
            else:
                _flush_noun_run()
        _flush_noun_run()

        # Descriptive words are also kept individually.
        phrases.extend(
            word.lower()
            for word, tag in tagged
            if tag.startswith('JJ') or tag in ('RB', 'RBR', 'RBS')
        )
    except Exception:
        # Deliberate best-effort: any NLTK failure (for example missing data
        # files) degrades to the plain word list instead of breaking the UI.
        phrases = list(_plain_words())

    # Also include original multi-word phrases
    multi_word = re.findall(r'\b[a-zA-Z]{3,}(?:\s+[a-zA-Z]{3,}){1,3}\b', text)
    phrases.extend(mw.lower() for mw in multi_word)

    return _longest_first(phrases)[:20]
def highlight_matches(original_input: str, enhanced_output: str) -> str:
    """Render ``enhanced_output`` as HTML with phrases from the input highlighted.

    Key phrases are extracted from ``original_input`` and searched for in
    ``enhanced_output`` as whole words, case-insensitively, longest phrase
    first. A match that overlaps an already accepted match is skipped, so a
    short phrase never splits a longer highlighted one.

    All matching is done against the untouched output text and the markup is
    assembled afterwards. This keeps match offsets valid, prevents phrases
    such as "class" or "mark" from matching inside the inserted tags, and
    lets every text segment be HTML-escaped before it reaches the page.

    Args:
        original_input: The prompt typed by the user.
        enhanced_output: The (possibly partial) generated text.

    Returns:
        A ``<p class="output-text">`` element containing the escaped text,
        with matches wrapped in ``<mark class="highlight-keyword">``.
    """
    import html  # function-scope import: the module does not import html at file level

    def _wrap(inner_html: str) -> str:
        return f'<p class="output-text">{inner_html}</p>'

    if not original_input.strip():
        return _wrap(html.escape(enhanced_output))

    key_phrases = extract_key_phrases(original_input)
    if not key_phrases:
        return _wrap(html.escape(enhanced_output))

    # Longer phrases claim their spans first.
    accepted = []
    for phrase in sorted(key_phrases, key=len, reverse=True):
        pattern = re.compile(r'\b' + re.escape(phrase) + r'\b', re.IGNORECASE)
        for match in pattern.finditer(enhanced_output):
            start, end = match.span()
            if start == end:
                continue  # an empty phrase would otherwise yield empty marks
            if any(start < taken_end and taken_start < end
                   for taken_start, taken_end in accepted):
                continue  # overlaps something already highlighted
            accepted.append((start, end))

    # Stitch escaped plain segments and highlighted segments back together.
    pieces = []
    cursor = 0
    for start, end in sorted(accepted):
        pieces.append(html.escape(enhanced_output[cursor:start]))
        pieces.append(
            f'<mark class="highlight-keyword">{html.escape(enhanced_output[start:end])}</mark>'
        )
        cursor = end
    pieces.append(html.escape(enhanced_output[cursor:]))
    return _wrap(''.join(pieces))
def generate_gpu(inputs, generation_kwargs):
    """Run ``generate`` on the model stored under ``models[True]``.

    Args:
        inputs: Mapping of model inputs, expanded as keyword arguments.
        generation_kwargs: Extra keyword arguments forwarded to ``generate``.

    Returns:
        Whatever the model's ``generate`` call returns.
    """
    gpu_model = models[True]
    return gpu_model.generate(**inputs, **generation_kwargs)
def _stream_generation(run_generate, streamer):
    """Run ``run_generate`` on a worker thread and yield the text produced so far.

    Args:
        run_generate: Zero-argument callable that performs the blocking
            ``generate`` call feeding ``streamer``.
        streamer: The ``TextIteratorStreamer`` passed to that ``generate`` call.

    Yields:
        The accumulated output string after each streamed piece.

    Raises:
        Exception: Whatever ``run_generate`` raised on the worker thread,
            re-raised here once the stream has been drained.
    """
    failures = []

    def _worker():
        try:
            run_generate()
        except Exception as exc:
            failures.append(exc)
            # Without this the consumer loop below would wait forever for an
            # end-of-stream marker that a crashed generate() never sends.
            # NOTE(review): relies on TextIteratorStreamer.end() enqueuing the
            # stop signal - confirm against the installed transformers version.
            streamer.end()

    worker = Thread(target=_worker)
    worker.start()

    accumulated = ""
    for piece in streamer:
        accumulated += piece
        yield accumulated

    worker.join()
    if failures:
        raise failures[0]


def enhance_prompt(user_prompt: str, use_gpu=CUDA_AVAILABLE):
    """Enhance the user's prompt using the AI model, streaming the result.

    This is a generator used as a Gradio event handler. Every yielded value is
    a 4-tuple ``(highlighted_html, raw_text, enhance_button_update,
    clear_button_update)``.

    Flow:
        1. An empty or blank prompt yields a single hint message and stops.
        2. Otherwise both buttons are disabled and generation starts on a
           worker thread; each streamed piece yields the text so far.
        3. If the GPU path raises a Gradio error, a warning is shown and
           generation is retried once on the CPU model.
        4. A final yield re-enables both buttons.

    Args:
        user_prompt: The concept typed by the user.
        use_gpu: Request GPU inference; only honoured when ``CUDA_AVAILABLE``
            is true.

    Yields:
        The 4-tuple described above.

    Raises:
        gr.Error: When CPU generation fails with a Gradio error.
        Exception: Any other failure is re-raised after the buttons have been
            re-enabled, so the UI is not left locked.
    """
    # Validate input
    if not user_prompt or not user_prompt.strip():
        yield (
            '<span class="placeholder-text">Please enter a prompt to enhance.</span>',
            "",
            gr.update(interactive=True),
            gr.update(interactive=True)
        )
        return

    # Prepare messages
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt}
    ]
    use_gpu = use_gpu and CUDA_AVAILABLE
    tokenizer = tokenizers[False]

    # Tokenize input
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    if use_gpu:
        inputs = {k: v.to('cuda') for k, v in inputs.items()}

    # Set up streaming
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = {
        'max_new_tokens': 512,
        'streamer': streamer,
        'do_sample': True,
        'temperature': 1,
        'top_p': 0.95,
        'top_k': 64
    }

    # Show loading state
    placeholder = '<span class="placeholder-text">Your enhanced prompt will appear here</span>'
    yield placeholder, "", gr.update(interactive=False), gr.update(interactive=False)

    def _cpu_runner(model_inputs):
        return lambda: models[False].generate(**model_inputs, **generation_kwargs)

    output = ""
    failure = None
    try:
        try:
            if use_gpu:
                runner = lambda: generate_gpu(inputs, generation_kwargs)
            else:
                runner = _cpu_runner(inputs)
            for output in _stream_generation(runner, streamer):
                yield highlight_matches(user_prompt, output), output, gr.update(), gr.update()
        except gr.exceptions.Error as e:
            if not use_gpu:
                raise gr.Error(str(e)) from e
            gr.Warning(str(e))
            gr.Info('Retrying with CPU')
            cpu_inputs = {k: v.cpu() for k, v in inputs.items()}
            output = ""  # discard any partial GPU output before retrying
            for output in _stream_generation(_cpu_runner(cpu_inputs), streamer):
                yield highlight_matches(user_prompt, output), output, gr.update(), gr.update()
    except Exception as exc:
        # Handled below, outside the except block, so the generator does not
        # yield while an exception is being handled.
        failure = exc

    if failure is not None:
        # Re-enable the buttons before surfacing the error; otherwise both stay
        # disabled and the user can neither retry nor clear.
        yield gr.update(), gr.update(), gr.update(interactive=True), gr.update(interactive=True)
        raise failure

    # Final output with interactive buttons restored
    final_highlighted = highlight_matches(user_prompt, output)
    yield final_highlighted, output, gr.update(interactive=True), gr.update(interactive=True)
# -----------------------------------------------------------------------------
# Stylesheet handed to gr.Blocks(css=...): a dark zinc palette in the style of
# shadcn/ui. Colours are stored as bare HSL triplets in custom properties and
# wrapped with hsl(...) at the point of use.
# -----------------------------------------------------------------------------
custom_css = """
/* ========== CSS VARIABLES ========== */
:root {
--background: 240 10% 3.9%;
--foreground: 0 0% 98%;
--card: 240 10% 4.5%;
--card-border: 240 3.7% 18%;
--primary: 0 0% 98%;
--primary-foreground: 240 5.9% 10%;
--secondary: 240 3.7% 15.9%;
--secondary-foreground: 0 0% 98%;
--muted: 240 3.7% 15.9%;
--muted-foreground: 240 5% 64.9%;
--accent: 240 3.7% 15.9%;
--accent-foreground: 0 0% 98%;
--border: 240 3.7% 18%;
--input: 240 3.7% 18%;
--ring: 240 5.9% 85%;
--radius: 0.625rem;
}
/* ========== GLOBAL STYLES ========== */
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
background: hsl(var(--background)) !important;
color: hsl(var(--foreground));
}
.gradio-container mark {
background: hsl(var(--accent) / 0.6);
color: hsl(var(--accent-foreground));
padding: 0.15em 0.35em;
border-radius: calc(var(--radius) - 2px);
font-weight: 500;
border: 1px solid hsl(var(--border) / 0.5);
}
footer { display: none !important; }
/* ========== MARKDOWN ========== */
.gradio-markdown {
color: hsl(var(--foreground)) !important;
font-size: 0.9375rem !important;
line-height: 1.6 !important;
}
.gradio-markdown:first-child {
margin-bottom: 2rem;
padding-bottom: 1.5rem;
border-bottom: 1px solid hsl(var(--border));
}
.gradio-markdown:last-child {
padding-top: 1.5rem;
border-top: 1px solid hsl(var(--border));
color: hsl(var(--muted-foreground)) !important;
}
.gradio-markdown a {
color: hsl(var(--foreground)) !important;
text-decoration: none;
border-bottom: 1px solid hsl(var(--border));
transition: border-color 0.2s ease;
}
.gradio-markdown a:hover {
border-color: hsl(var(--ring));
}
/* ========== LAYOUT ========== */
.main-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 2rem;
}
@media (max-width: 768px) {
.main-grid { grid-template-columns: 1fr; }
}
/* ========== CARDS ========== */
.card {
background: hsl(var(--card));
border: 1px solid hsl(var(--card-border));
border-radius: var(--radius);
padding: 1.5rem;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.3), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
}
/* ========== FORM ELEMENTS ========== */
.form-label {
font-size: 0.875rem;
font-weight: 500;
margin-bottom: 0.5rem;
display: block;
color: hsl(var(--foreground));
}
.input-textarea {
width: 100%;
min-height: 140px;
padding: 0.875rem;
font-size: 0.9375rem;
line-height: 1.6;
background: hsl(var(--background));
border: 1px solid hsl(var(--input));
border-radius: var(--radius);
color: hsl(var(--foreground));
transition: all 0.2s ease;
resize: vertical;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
}
.input-textarea::placeholder {
color: hsl(var(--muted-foreground) / 0.7);
}
.input-textarea:focus {
outline: none;
border-color: hsl(var(--ring));
box-shadow: 0 0 0 3px hsl(var(--ring) / 0.1), 0 1px 2px rgba(0, 0, 0, 0.2);
background: hsl(var(--background) / 0.8);
}
/* ========== BUTTONS ========== */
.btn {
display: inline-flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
font-size: 0.9375rem;
font-weight: 500;
padding: 0.625rem 1.25rem;
border-radius: var(--radius);
cursor: pointer;
transition: all 0.2s ease;
border: none;
}
.btn:focus-visible {
outline: none;
box-shadow: 0 0 0 2px hsl(var(--background)), 0 0 0 4px hsl(var(--ring));
}
.btn-primary {
background: hsl(var(--primary));
color: hsl(var(--primary-foreground));
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2), 0 0 0 1px rgba(255, 255, 255, 0.05) inset;
}
.btn-primary:hover {
opacity: 0.95;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.25), 0 0 0 1px rgba(255, 255, 255, 0.08) inset;
}
.btn-primary:active {
transform: translateY(1px);
}
.btn-primary:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.btn-secondary {
background: hsl(var(--secondary));
color: hsl(var(--secondary-foreground));
border: 1px solid hsl(var(--border));
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
}
.btn-secondary:hover {
background: hsl(var(--secondary) / 0.8);
border-color: hsl(var(--muted-foreground) / 0.5);
}
.btn-secondary:active {
transform: translateY(1px);
}
/* ========== OUTPUT CONTAINER ========== */
.output-container {
min-height: 140px;
padding: 0.875rem;
border: 1px solid hsl(var(--input));
border-radius: var(--radius);
background: hsl(var(--background));
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.15), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
}
.output-text {
color: hsl(var(--foreground));
font-size: 0.9375rem;
line-height: 1.75;
margin: 0;
}
.placeholder-text {
color: hsl(var(--muted-foreground));
}
.highlight-keyword {
background: hsl(var(--accent) / 0.6);
color: hsl(var(--accent-foreground));
padding: 0.15em 0.35em;
border-radius: calc(var(--radius) - 2px);
font-weight: 500;
border: 1px solid hsl(var(--border) / 0.5);
}
/* ========== EXAMPLES ========== */
.examples-section {
padding: 1.5rem;
background: hsl(var(--card));
border: 1px solid hsl(var(--card-border));
border-radius: var(--radius);
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
}
/* ========== SPACING UTILITIES ========== */
.mt-6 { margin-top: 1.5rem; }
.flex { display: flex; }
.gap-2 { gap: 0.5rem; }
"""
# -----------------------------------------------------------------------------
# Gradio interface: layout first, event wiring at the end of the Blocks context.
# -----------------------------------------------------------------------------
with gr.Blocks(css=custom_css, title="Prompt Enhancer") as demo:
    # Text shown in the output panel whenever there is no result to display.
    _output_placeholder = '<span class="placeholder-text">Your enhanced prompt will appear here</span>'

    # Header
    with gr.Row():
        gr.Markdown("Transform your creative ideas into detailed, vivid prompts for AI image generation.")

    # Main content - two column layout
    with gr.Row(elem_classes=["main-grid"]):
        # Left card: prompt entry plus the two action buttons.
        with gr.Column(elem_classes=["card"]):
            gr.HTML('<label class="form-label">Input Prompt</label>')
            input_text = gr.Textbox(
                placeholder="Describe your image concept... e.g., fox, red tail, blue moon, clouds",
                lines=5,
                show_label=False,
                autofocus=True,
                container=False,
                elem_classes=["input-textarea"],
            )
            with gr.Row(elem_classes=["flex gap-2 mt-6"]):
                enhance_btn = gr.Button(
                    "Enhance Prompt",
                    variant="primary",
                    scale=2,
                    elem_classes=["btn", "btn-primary"],
                )
                clear_btn = gr.Button(
                    "Clear",
                    scale=1,
                    elem_classes=["btn", "btn-secondary"],
                )

        # Right card: highlighted HTML result, plus a hidden textbox that
        # receives the plain generated text.
        with gr.Column(elem_classes=["card"]):
            gr.HTML('<label class="form-label">Enhanced Prompt</label>')
            output_html = gr.HTML(
                value=_output_placeholder,
                elem_classes=["output-container"],
            )
            raw_output = gr.Textbox(visible=False)

    # Examples section
    with gr.Column(elem_classes=["examples-section"]):
        gr.Examples(
            examples=[
                ["fox, red tail, blue moon, clouds"],
                ["room with french window, cozy morning vibes, minimal"],
                ["anime style, sunset, japan"],
            ],
            inputs=input_text,
            label="Examples",
        )

    # Footer
    with gr.Row():
        gr.Markdown(
            "Powered by [PromptTuner](https://huggingface.co/shb777/PromptTuner-v0.1), "
            "a finetuned gemma3-270M model specifically designed to enhance text prompts "
            "for text-to-image generation."
        )

    # -------------------------------------------------------------------------
    # Event handlers
    # -------------------------------------------------------------------------
    def _reset_ui():
        """Return values that blank the input and both outputs and re-enable the buttons."""
        return (
            "",
            _output_placeholder,
            "",
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    # The constant State(False) is passed as ``use_gpu``.
    enhance_btn.click(
        fn=enhance_prompt,
        inputs=[input_text, gr.State(False)],
        outputs=[output_html, raw_output, enhance_btn, clear_btn],
    )
    clear_btn.click(
        fn=_reset_ui,
        inputs=None,
        outputs=[input_text, output_html, raw_output, enhance_btn, clear_btn],
    )
if __name__ == "__main__":
    # Enable the request queue (max_size=20), then start the server with the
    # MCP server option switched on.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(mcp_server=True)