TinkerSpace / app.py
shb777's picture
Super-squash branch 'main' using huggingface_hub
c3b6ca1 verified
import re
import spaces
import gradio as gr
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
try:
import nltk
from nltk.tokenize import word_tokenize
from nltk.chunk import ne_chunk
from nltk.tag import pos_tag
NLTK_AVAILABLE = True
except ImportError:
NLTK_AVAILABLE = False
# System message placed first in the chat messages that enhance_prompt builds.
# It instructs the model to enrich the user's concept with visual detail while
# keeping the original subject, action and setting, and to reply with a single
# paragraph. The text is sent to the model verbatim, so edits here change
# generation behaviour.
SYSTEM_PROMPT = """You are an expert creative director specializing in visual descriptions for image generation.
Your task: Transform the user's concept into a rich, detailed image description while PRESERVING their core idea.
IMPORTANT RULES:
1. Keep ALL key elements (intents, entities) from the original concept
2. Enhance with artistic details, NOT change the fundamental idea
3. Maintain the user's intended subject, action, and setting
You should elaborate on:
• Visual composition and perspective (bird's eye, close-up, wide angle, etc.)
• Artistic style (photorealistic, impressionist, specific artist like Van Gogh, etc.)
• Color palette and color temperature
• Lighting (golden hour, dramatic shadows, soft diffused, etc.)
• Atmosphere and mood
• Textures and materials (rough, smooth, metallic, organic, etc.)
• Technical details (medium, brushwork, rendering style)
• Environmental context (time of day, weather, season, era)
• Level of detail and focus points
Output format: A single, flowing paragraph that reads naturally as an image prompt."""
# Checkpoint shared by the tokenizer and by every model copy loaded below.
_MODEL_ID = "shb777/PromptTuner-v0.1"

# GPU switch: while this is False only the CPU copy is loaded and used.
CUDA_AVAILABLE = False

# Both registries are keyed by "runs on GPU": False -> CPU, True -> CUDA.
models = {False: AutoModelForCausalLM.from_pretrained(_MODEL_ID)}
tokenizers = {False: AutoTokenizer.from_pretrained(_MODEL_ID)}
models[False].eval()

if CUDA_AVAILABLE:
    # A second, independent copy of the weights is moved to the GPU; the
    # tokenizer object is simply shared with the CPU entry.
    models[True] = AutoModelForCausalLM.from_pretrained(_MODEL_ID).to('cuda')
    tokenizers[True] = tokenizers[False]
    models[True].eval()
# Make sure the NLTK data used for tokenizing, POS tagging and named-entity
# chunking is installed before the first request arrives.
if NLTK_AVAILABLE:
    # (path probed with nltk.data.find, package id passed to nltk.download)
    _NLTK_RESOURCES = (
        ('tokenizers/punkt', 'punkt'),
        ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
        ('chunkers/maxent_ne_chunker', 'maxent_ne_chunker'),
        ('corpora/words', 'words'),
    )
    try:
        for _resource_path, _package in _NLTK_RESOURCES:
            nltk.data.find(_resource_path)
    except LookupError:
        # The first missing resource triggers a quiet download of all four.
        for _resource_path, _package in _NLTK_RESOURCES:
            nltk.download(_package, quiet=True)
def extract_key_phrases(text: str) -> list:
    """Return lowercase key phrases found in ``text``, longest first.

    Without NLTK the result is every distinct alphabetic word of three or
    more letters. With NLTK the result combines named entities, runs of
    nouns (optionally continued by adjectives), stand-alone adjectives and
    adverbs, and 2-4 word sequences taken straight from the text; that
    combined list is capped at 20 entries.

    The ordering is deterministic: phrases are sorted by decreasing length
    and ties are broken alphabetically. Sorting a ``set`` by length alone
    would leave ties in hash-randomised order, so the 20-entry cut could keep
    different phrases from one process to the next.

    Args:
        text: The user's original prompt.

    Returns:
        A list of unique lowercase phrases, longest first.
    """
    def ordered_unique(found):
        # Longest first, alphabetical among equal lengths.
        return sorted(set(found), key=lambda phrase: (-len(phrase), phrase))

    if not NLTK_AVAILABLE:
        return ordered_unique(re.findall(r'\b[a-zA-Z]{3,}\b', text.lower()))

    phrases = []
    try:
        tagged = pos_tag(word_tokenize(text))
        noun_run = []
        for node in ne_chunk(tagged):
            if hasattr(node, 'label'):
                # Named-entity subtree. Close the pending noun run first so
                # words on either side of the entity are not merged into a
                # phrase that never occurs in the text.
                if noun_run:
                    phrases.append(' '.join(noun_run).lower())
                    noun_run = []
                entity = ' '.join(token for token, _ in node.leaves())
                phrases.append(entity.lower())
                continue

            word, tag = node[0], node[1]
            if tag.startswith('NN') or (tag.startswith('JJ') and noun_run):
                # Nouns start or extend a run; adjectives only extend one.
                noun_run.append(word)
            elif noun_run:
                phrases.append(' '.join(noun_run).lower())
                noun_run = []
        if noun_run:
            phrases.append(' '.join(noun_run).lower())

        # Adjectives and adverbs are also kept as stand-alone keywords.
        phrases.extend(
            word.lower()
            for word, tag in tagged
            if tag.startswith('JJ') or tag in ('RB', 'RBR', 'RBS')
        )
    except Exception:
        # Deliberate best effort: if any NLTK step fails (for example a
        # missing data package), fall back to plain word extraction instead
        # of failing the request.
        phrases = list(set(re.findall(r'\b[a-zA-Z]{3,}\b', text.lower())))

    # Also include original multi-word phrases
    multi_word = re.findall(r'\b[a-zA-Z]{3,}(?:\s+[a-zA-Z]{3,}){1,3}\b', text)
    phrases.extend(mw.lower() for mw in multi_word)

    return ordered_unique(phrases)[:20]
def highlight_matches(original_input: str, enhanced_output: str) -> str:
    """Render ``enhanced_output`` as HTML with the user's key phrases marked.

    Key phrases are taken from ``original_input`` via ``extract_key_phrases``
    and matched case-insensitively on word boundaries. Longer phrases are
    matched first and a shorter phrase is skipped wherever it overlaps a span
    that is already highlighted.

    All match positions are computed on the raw text and the HTML is built
    once at the end. Two consequences:

    * The generated text is HTML-escaped, so ``<``, ``>`` and ``&`` produced
      by the model are shown literally instead of being parsed as markup.
    * Phrases can never match inside an inserted ``<mark ...>`` tag, and the
      overlap bookkeeping always refers to the same coordinate system.

    Args:
        original_input: The prompt typed by the user.
        enhanced_output: The text generated so far.

    Returns:
        A ``<p class="output-text">`` element as an HTML string.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    def as_paragraph(inner_html):
        return f'<p class="output-text">{inner_html}</p>'

    if not original_input.strip():
        return as_paragraph(escape(enhanced_output, quote=False))

    key_phrases = extract_key_phrases(original_input)
    if not key_phrases:
        return as_paragraph(escape(enhanced_output, quote=False))

    # Collect non-overlapping (start, end) spans, longest phrases first.
    spans = []
    for phrase in sorted(key_phrases, key=len, reverse=True):
        if not phrase:
            continue
        pattern = re.compile(r'\b' + re.escape(phrase) + r'\b', re.IGNORECASE)
        for match in pattern.finditer(enhanced_output):
            start, end = match.span()
            if all(end <= taken_start or start >= taken_end
                   for taken_start, taken_end in spans):
                spans.append((start, end))

    # Stitch the result together, escaping every piece of model text.
    pieces = []
    cursor = 0
    for start, end in sorted(spans):
        pieces.append(escape(enhanced_output[cursor:start], quote=False))
        marked = escape(enhanced_output[start:end], quote=False)
        pieces.append(f'<mark class="highlight-keyword">{marked}</mark>')
        cursor = end
    pieces.append(escape(enhanced_output[cursor:], quote=False))

    return as_paragraph(''.join(pieces))
@spaces.GPU(duration=30)
def generate_gpu(inputs, generation_kwargs):
    """Run ``generate`` on the model stored under ``models[True]``.

    Args:
        inputs: Mapping of model inputs, expanded into keyword arguments.
        generation_kwargs: Extra keyword arguments forwarded to ``generate``.

    Returns:
        Whatever the model's ``generate`` call returns.
    """
    gpu_model = models[True]
    return gpu_model.generate(**inputs, **generation_kwargs)
def _run_generation(generate, streamer, errors):
    """Thread target: call ``generate()`` and report a failure to the reader."""
    try:
        generate()
    except Exception as exc:
        errors.append(exc)
        # Queue the stop signal so the loop reading the streamer terminates
        # instead of waiting forever for tokens that will never arrive.
        streamer.end()


def _stream_generation(tokenizer, inputs, sampling_kwargs, on_gpu):
    """Yield decoded text pieces from one generation attempt.

    Generation runs in a worker thread with its own fresh streamer. An
    exception raised by the worker is re-raised here, in the consumer's
    thread, once the stream has ended.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = {**sampling_kwargs, 'streamer': streamer}

    if on_gpu:
        def generate():
            generate_gpu(inputs=inputs, generation_kwargs=generation_kwargs)
    else:
        def generate():
            models[False].generate(**inputs, **generation_kwargs)

    errors = []
    worker = Thread(target=_run_generation, args=(generate, streamer, errors))
    worker.start()
    for text in streamer:
        yield text
    worker.join()
    if errors:
        raise errors[0]


def enhance_prompt(user_prompt: str, use_gpu=CUDA_AVAILABLE):
    """Enhance the user's prompt using the AI model.

    This is a generator used as a Gradio event handler. Every yielded tuple
    is ``(highlighted_html, raw_text, enhance_button_update,
    clear_button_update)``.

    Flow:
        1. An empty prompt yields a hint and stops.
        2. Both buttons are disabled while the model generates.
        3. Text is streamed; each update re-renders the highlighted HTML.
        4. If a GPU attempt fails with a Gradio error, a warning is shown and
           generation restarts on the CPU model with a fresh streamer.
        5. On success the final text is yielded with the buttons re-enabled.
           On failure the buttons are re-enabled and the error is re-raised.

    Generation errors happen in a worker thread; they are carried back to
    this generator so the fallback and error reporting can act on them.

    Args:
        user_prompt: The concept typed by the user.
        use_gpu: Request GPU generation; only honoured when
            ``CUDA_AVAILABLE`` is true.

    Yields:
        4-tuples of output values/updates as described above.

    Raises:
        gr.exceptions.Error: When generation fails with a Gradio error and no
            fallback is left.
        Exception: Any other error raised during generation.
    """
    # Validate input
    if not user_prompt or not user_prompt.strip():
        yield (
            '<span class="placeholder-text">Please enter a prompt to enhance.</span>',
            "",
            gr.update(interactive=True),
            gr.update(interactive=True)
        )
        return

    # Prepare messages
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt}
    ]
    use_gpu = use_gpu and CUDA_AVAILABLE
    tokenizer = tokenizers[False]

    # Tokenize input
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    if use_gpu:
        inputs = {k: v.to('cuda') for k, v in inputs.items()}

    sampling_kwargs = {
        'max_new_tokens': 512,
        'do_sample': True,
        'temperature': 1,
        'top_p': 0.95,
        'top_k': 64
    }

    # Show loading state
    placeholder = '<span class="placeholder-text">Your enhanced prompt will appear here</span>'
    yield placeholder, "", gr.update(interactive=False), gr.update(interactive=False)

    output = ""
    failure = None
    try:
        try:
            for text in _stream_generation(tokenizer, inputs, sampling_kwargs, use_gpu):
                output += text
                yield highlight_matches(user_prompt, output), output, gr.update(), gr.update()
        except gr.exceptions.Error as e:
            if not use_gpu:
                raise
            gr.Warning(str(e))
            gr.Info('Retrying with CPU')
            cpu_inputs = {k: v.cpu() for k, v in inputs.items()}
            # Restart from scratch: discard any partial GPU output.
            output = ""
            for text in _stream_generation(tokenizer, cpu_inputs, sampling_kwargs, False):
                output += text
                yield highlight_matches(user_prompt, output), output, gr.update(), gr.update()
    except Exception as exc:
        failure = exc

    if failure is not None:
        # Leave the displayed text untouched but give the buttons back,
        # otherwise the UI stays locked after an error.
        yield gr.update(), gr.update(), gr.update(interactive=True), gr.update(interactive=True)
        raise failure

    # Final output with interactive buttons restored
    final_highlighted = highlight_matches(user_prompt, output)
    yield final_highlighted, output, gr.update(interactive=True), gr.update(interactive=True)
# =============================================================================
# CSS - shadcn/ui inspired Zinc Dark Theme
# =============================================================================
# Stylesheet passed to gr.Blocks(css=...) when the interface is built.
# The :root variables hold bare HSL triplets, so every use wraps them in
# hsl(var(--name)), optionally with "/ alpha".
# Classes such as .main-grid, .card, .btn, .input-textarea, .output-container
# and .examples-section are attached to components through elem_classes;
# .output-text, .placeholder-text and .highlight-keyword appear in the HTML
# strings produced by the Python code.
custom_css = """
/* ========== CSS VARIABLES ========== */
:root {
--background: 240 10% 3.9%;
--foreground: 0 0% 98%;
--card: 240 10% 4.5%;
--card-border: 240 3.7% 18%;
--primary: 0 0% 98%;
--primary-foreground: 240 5.9% 10%;
--secondary: 240 3.7% 15.9%;
--secondary-foreground: 0 0% 98%;
--muted: 240 3.7% 15.9%;
--muted-foreground: 240 5% 64.9%;
--accent: 240 3.7% 15.9%;
--accent-foreground: 0 0% 98%;
--border: 240 3.7% 18%;
--input: 240 3.7% 18%;
--ring: 240 5.9% 85%;
--radius: 0.625rem;
}
/* ========== GLOBAL STYLES ========== */
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
background: hsl(var(--background)) !important;
color: hsl(var(--foreground));
}
.gradio-container mark {
background: hsl(var(--accent) / 0.6);
color: hsl(var(--accent-foreground));
padding: 0.15em 0.35em;
border-radius: calc(var(--radius) - 2px);
font-weight: 500;
border: 1px solid hsl(var(--border) / 0.5);
}
footer { display: none !important; }
/* ========== MARKDOWN ========== */
.gradio-markdown {
color: hsl(var(--foreground)) !important;
font-size: 0.9375rem !important;
line-height: 1.6 !important;
}
.gradio-markdown:first-child {
margin-bottom: 2rem;
padding-bottom: 1.5rem;
border-bottom: 1px solid hsl(var(--border));
}
.gradio-markdown:last-child {
padding-top: 1.5rem;
border-top: 1px solid hsl(var(--border));
color: hsl(var(--muted-foreground)) !important;
}
.gradio-markdown a {
color: hsl(var(--foreground)) !important;
text-decoration: none;
border-bottom: 1px solid hsl(var(--border));
transition: border-color 0.2s ease;
}
.gradio-markdown a:hover {
border-color: hsl(var(--ring));
}
/* ========== LAYOUT ========== */
.main-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 2rem;
}
@media (max-width: 768px) {
.main-grid { grid-template-columns: 1fr; }
}
/* ========== CARDS ========== */
.card {
background: hsl(var(--card));
border: 1px solid hsl(var(--card-border));
border-radius: var(--radius);
padding: 1.5rem;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.3), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
}
/* ========== FORM ELEMENTS ========== */
.form-label {
font-size: 0.875rem;
font-weight: 500;
margin-bottom: 0.5rem;
display: block;
color: hsl(var(--foreground));
}
.input-textarea {
width: 100%;
min-height: 140px;
padding: 0.875rem;
font-size: 0.9375rem;
line-height: 1.6;
background: hsl(var(--background));
border: 1px solid hsl(var(--input));
border-radius: var(--radius);
color: hsl(var(--foreground));
transition: all 0.2s ease;
resize: vertical;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
}
.input-textarea::placeholder {
color: hsl(var(--muted-foreground) / 0.7);
}
.input-textarea:focus {
outline: none;
border-color: hsl(var(--ring));
box-shadow: 0 0 0 3px hsl(var(--ring) / 0.1), 0 1px 2px rgba(0, 0, 0, 0.2);
background: hsl(var(--background) / 0.8);
}
/* ========== BUTTONS ========== */
.btn {
display: inline-flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
font-size: 0.9375rem;
font-weight: 500;
padding: 0.625rem 1.25rem;
border-radius: var(--radius);
cursor: pointer;
transition: all 0.2s ease;
border: none;
}
.btn:focus-visible {
outline: none;
box-shadow: 0 0 0 2px hsl(var(--background)), 0 0 0 4px hsl(var(--ring));
}
.btn-primary {
background: hsl(var(--primary));
color: hsl(var(--primary-foreground));
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2), 0 0 0 1px rgba(255, 255, 255, 0.05) inset;
}
.btn-primary:hover {
opacity: 0.95;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.25), 0 0 0 1px rgba(255, 255, 255, 0.08) inset;
}
.btn-primary:active {
transform: translateY(1px);
}
.btn-primary:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.btn-secondary {
background: hsl(var(--secondary));
color: hsl(var(--secondary-foreground));
border: 1px solid hsl(var(--border));
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
}
.btn-secondary:hover {
background: hsl(var(--secondary) / 0.8);
border-color: hsl(var(--muted-foreground) / 0.5);
}
.btn-secondary:active {
transform: translateY(1px);
}
/* ========== OUTPUT CONTAINER ========== */
.output-container {
min-height: 140px;
padding: 0.875rem;
border: 1px solid hsl(var(--input));
border-radius: var(--radius);
background: hsl(var(--background));
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.15), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
}
.output-text {
color: hsl(var(--foreground));
font-size: 0.9375rem;
line-height: 1.75;
margin: 0;
}
.placeholder-text {
color: hsl(var(--muted-foreground));
}
.highlight-keyword {
background: hsl(var(--accent) / 0.6);
color: hsl(var(--accent-foreground));
padding: 0.15em 0.35em;
border-radius: calc(var(--radius) - 2px);
font-weight: 500;
border: 1px solid hsl(var(--border) / 0.5);
}
/* ========== EXAMPLES ========== */
.examples-section {
padding: 1.5rem;
background: hsl(var(--card));
border: 1px solid hsl(var(--card-border));
border-radius: var(--radius);
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
}
/* ========== SPACING UTILITIES ========== */
.mt-6 { margin-top: 1.5rem; }
.flex { display: flex; }
.gap-2 { gap: 0.5rem; }
"""
# =============================================================================
# Gradio Interface
# =============================================================================
# Static copy shared by the layout and the reset handler below.
_OUTPUT_PLACEHOLDER_HTML = '<span class="placeholder-text">Your enhanced prompt will appear here</span>'
_EXAMPLE_PROMPTS = [
    ["fox, red tail, blue moon, clouds"],
    ["room with french window, cozy morning vibes, minimal"],
    ["anime style, sunset, japan"]
]
_FOOTER_MARKDOWN = (
    "Powered by [PromptTuner](https://huggingface.co/shb777/PromptTuner-v0.1), "
    "a finetuned gemma3-270M model specifically designed to enhance text prompts "
    "for text-to-image generation."
)

with gr.Blocks(css=custom_css, title="Prompt Enhancer") as demo:
    # Header
    with gr.Row():
        gr.Markdown("Transform your creative ideas into detailed, vivid prompts for AI image generation.")

    # Main content - two column layout
    with gr.Row(elem_classes=["main-grid"]):
        # Left card: prompt entry and the two action buttons
        with gr.Column(elem_classes=["card"]):
            gr.HTML('<label class="form-label">Input Prompt</label>')
            input_text = gr.Textbox(
                lines=5,
                placeholder="Describe your image concept... e.g., fox, red tail, blue moon, clouds",
                autofocus=True,
                show_label=False,
                container=False,
                elem_classes=["input-textarea"]
            )
            with gr.Row(elem_classes=["flex gap-2 mt-6"]):
                enhance_btn = gr.Button(
                    "Enhance Prompt",
                    variant="primary",
                    scale=2,
                    elem_classes=["btn", "btn-primary"]
                )
                clear_btn = gr.Button(
                    "Clear",
                    scale=1,
                    elem_classes=["btn", "btn-secondary"]
                )

        # Right card: highlighted result plus a hidden copy of the raw text
        with gr.Column(elem_classes=["card"]):
            gr.HTML('<label class="form-label">Enhanced Prompt</label>')
            output_html = gr.HTML(
                value=_OUTPUT_PLACEHOLDER_HTML,
                elem_classes=["output-container"]
            )
            raw_output = gr.Textbox(visible=False)

    # Examples section
    with gr.Column(elem_classes=["examples-section"]):
        gr.Examples(
            examples=_EXAMPLE_PROMPTS,
            inputs=input_text,
            label="Examples"
        )

    # Footer
    with gr.Row():
        gr.Markdown(_FOOTER_MARKDOWN)

    # =========================================================================
    # Event Handlers
    # =========================================================================
    # Streaming handler; the constant State(False) is passed as use_gpu.
    enhance_targets = [output_html, raw_output, enhance_btn, clear_btn]
    enhance_btn.click(
        fn=enhance_prompt,
        inputs=[input_text, gr.State(False)],
        outputs=enhance_targets
    )

    # Reset: empty both text values, restore the placeholder, enable buttons.
    reset_targets = [input_text, output_html, raw_output, enhance_btn, clear_btn]
    clear_btn.click(
        fn=lambda: (
            "",
            _OUTPUT_PLACEHOLDER_HTML,
            "",
            gr.update(interactive=True),
            gr.update(interactive=True)
        ),
        inputs=None,
        outputs=reset_targets
    )
if __name__ == "__main__":
    # Enable request queuing (max_size=20), then launch with mcp_server=True.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(mcp_server=True)