# Prompteador — Hugging Face Space "app.py" (author: Malaji71, commit fe8a377)
# app.py — PromptCraft: Refinamiento Estructural de Prompts (versión equilibrada)
import gradio as gr
import os
import time
import logging
from typing import Optional, Tuple
from PIL import Image
from agent import ImprovedSemanticAgent
from huggingface_hub import InferenceClient
from transformers import pipeline
from openai import OpenAI, APIError, Timeout
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LlamaRefiner:
    """Refine Spanish user prompts into English diffusion-model prompts.

    Pipeline: local es→en translation (Helsinki-NLP opus-mt) with keyword
    post-corrections, FAISS retrieval of similar dataset prompts via the
    semantic agent, then an LLM rewrite (Together through the HF router,
    falling back to the HF Inference client, then to the local agent).
    """

    def __init__(self):
        # The HF token is stored in the Space secret named "PS".
        hf_token = os.getenv("PS")
        if not hf_token:
            raise ValueError("Secret 'PS' (HF_TOKEN) no encontrado.")
        self.hf_client = InferenceClient(api_key=hf_token)
        self.agent = ImprovedSemanticAgent()
        if not self.agent.is_ready:
            # Force the agent's lazy initialization once at startup.
            init_msg = self.agent._lazy_init()
            logger.info(f"Inicialización del agente: {init_msg}")
            if not self.agent.is_ready:
                logger.error("❌ Agente NO está listo.")
        logger.info("🚀 Cargando traductor local (es → en)...")
        # CPU-only translation pipeline (device=-1).
        self.translator = pipeline(
            "translation_es_to_en",
            model="Helsinki-NLP/opus-mt-es-en",
            device=-1
        )
        logger.info("✅ Traductor local listo.")

    def translate_to_english(self, text: str) -> str:
        """Translate *text* to English and patch common mistranslations.

        Falls back to the original text when the local translator fails.
        Keyword checks on the Spanish input force literal phrasing the MT
        model tends to soften (fire, gold, frozen).
        """
        if not text.strip():
            return text
        try:
            result = self.translator(text, max_length=250, clean_up_tokenization_spaces=True)
            raw_translation = result[0]['translation_text'].strip()
        except Exception as e:
            logger.warning(f"Traducción local fallida: {e}. Usando texto original.")
            raw_translation = text
        user_text_lower = text.lower()
        output = raw_translation
        # "en llamas" is often rendered as "fiery"; force literal flames.
        if any(kw in user_text_lower for kw in ["llamas", "ardiendo", "quem", "incendi", "fuego"]):
            output = output.replace("fiery", "on fire")
            if not any(term in output.lower() for term in ["on fire", "burning", "in flames", "ablaze", "aflame"]):
                output = output + " on fire"
        if any(kw in user_text_lower for kw in ["oro", "dorado"]):
            if "golden" not in output.lower() and "gold" not in output.lower():
                if any(w in output.lower() for w in ["statue", "sculpture", "figure"]):
                    output = output + " made of gold"
                else:
                    output = output + " golden"
        if any(kw in user_text_lower for kw in ["congelado", "hielo", "helado", "ice"]):
            if not any(term in output.lower() for term in ["frozen", "ice", "icy"]):
                output = output + " frozen"
        return output.strip()

    def retrieve_similar_examples(self, user_prompt_en: str, category: str = "auto", k: int = 6) -> list:
        """Return up to *k* dataset captions semantically close to the prompt.

        Retrieves 50 FAISS candidates, filters by category and caption
        quality, then reranks with the cross-encoder when there are more
        than *k* survivors. Returns [] when the agent is not ready or
        retrieval fails before any candidates exist.
        """
        if not self.agent.is_ready:
            return []
        indices = None  # kept outside the try so the fallback below is safe
        try:
            query_embedding = self.agent.embedding_model.encode(
                [user_prompt_en], convert_to_numpy=True, normalize_embeddings=True
            )[0]
            query_embedding = query_embedding.astype('float32').reshape(1, -1)
            distances, indices = self.agent.index.search(query_embedding, 50)
            candidates = []
            for idx in indices[0]:
                if 0 <= idx < len(self.agent.indexed_examples):
                    ex = self.agent.indexed_examples[idx]
                    caption = ex.get('caption', '')
                    ex_category = ex.get('category', 'auto')
                    # Skip degenerate captions; honor the category filter.
                    if isinstance(caption, str) and len(caption) > 10:
                        if category == "auto" or ex_category == category:
                            candidates.append((idx, caption, ex_category))
            if not candidates:
                return []
            if len(candidates) <= k:
                return [cap for _, cap, _ in candidates]
            candidate_texts = [cap for _, cap, _ in candidates]
            pairs = [[user_prompt_en, cand] for cand in candidate_texts]
            scores = self.agent.reranker.predict(pairs)
            scored = [(candidates[i][1], scores[i]) for i in range(len(candidates))]
            scored.sort(key=lambda x: x[1], reverse=True)
            return [ex for ex, _ in scored[:k]]
        except Exception as e:
            logger.error(f"Error en recuperación: {e}")
            # FIX: `indices` was previously referenced unconditionally here,
            # raising NameError (hidden by a bare except) when the failure
            # happened before index.search().
            if indices is None:
                return []
            try:
                return [
                    self.agent.indexed_examples[idx]['caption']
                    for idx in indices[0][:k]
                    if 0 <= idx < len(self.agent.indexed_examples)
                ]
            except Exception:
                return []

    def _clean_output(self, text: str) -> str:
        """Strip LLM boilerplate prefixes and wrapping quotes from *text*.

        Quote-wrapped text is unwrapped without splitting on ':' so that a
        quoted prompt containing a colon is not truncated (the previous
        version reduced '"a cat: red"' to 'red').
        """
        text = text.strip()
        if text.startswith(('"', "'")) and text.endswith(('"', "'")):
            text = text.strip("\"'").strip()
        if text.startswith(("Here is", "Final:", "Output:")):
            text = text.split(":", 1)[-1].strip().strip("\"'")
        return text

    def refine_with_llm(self, user_prompt: str, category: str = "auto") -> Tuple[str, str, list]:
        """Refine a Spanish *user_prompt* into an English diffusion prompt.

        Returns (refined_prompt, status_info, examples_used). Tries the
        Together provider through the HF router first, then the HF
        inference client, then the local semantic agent as a last resort.
        """
        user_prompt_en = self.translate_to_english(user_prompt)
        examples = self.retrieve_similar_examples(user_prompt_en, category=category, k=6)
        if not examples:
            # Curated fallback prompts per category when retrieval is empty.
            fallbacks = {
                "entity": [
                    "an elderly maya man weaving a hammock under a ceiba tree, golden hour light filtering through leaves, Antigua Guatemala setting, hyperrealistic style",
                    "a young indigenous woman in traditional Kekchi attire by Lake Atitlán, morning mist, volcano backdrop, soft natural light, documentary photography"
                ],
                "style": [
                    "oil painting of a forest in autumn, warm amber and crimson tones, impasto brushstrokes, style of Vincent van Gogh",
                    "cyberpunk cityscape at night, neon reflections on wet streets, cinematic lighting, style of Blade Runner 2049"
                ],
                "composition": [
                    "a lone wolf on a snowy mountain peak, northern lights in the sky, wind blowing snow, rule of thirds composition, photorealistic wildlife"
                ],
                "imaginative": [
                    "a floating island with ancient ruins, waterfalls cascading into clouds, golden hour, fantasy concept art, highly detailed"
                ],
                "text": [
                    "minimalist typography design, the word 'LIBERTAD' in bold sans-serif, high contrast black on white, professional layout"
                ],
                "auto": [
                    "an elderly maya man weaving a hammock under a ceiba tree, golden hour light filtering through leaves, Antigua Guatemala setting, hyperrealistic style, intricate textures of rope and bark",
                    "a cyberpunk street at night in Tokyo, neon signs reflecting on wet pavement, rain mist in air, distant flying cars, cinematic wide shot, Blade Runner atmosphere",
                    "a library interior with tall oak bookshelves, sunbeams through stained glass windows, dust particles floating, oil painting style, warm amber tones, masterpiece",
                    "a lone wolf howling on a snowy mountain peak, northern lights in the sky, wind blowing snow, photorealistic wildlife photography, 8k detailed fur",
                    "a steampunk airship floating above Victorian London, copper pipes and brass gears, cloudy sky, detailed machinery, concept art by Jakub Rozalski",
                    "a young woman in traditional Kekchi attire standing by Lake Atitlán, morning mist, volcano backdrop, soft natural light, documentary photography style"
                ]
            }
            examples = fallbacks.get(category, fallbacks["auto"])
            logger.warning(f"No se encontraron ejemplos para categoría '{category}'. Usando fallback.")
        # Flexible system message with an emphasis on literal fidelity.
        system_message = (
            "You are a prompt engineering analyst for diffusion models (Midjourney, FLUX, SDXL). "
            "Analyze the DESCRIPTIVE GRAMMAR (word order, phrasing, element sequence) used in the reference prompts below. "
            "Reconstruct the user's concept using that exact same descriptive logic. "
            "Do NOT follow a predefined template (e.g. subject→lighting→style). "
            "Do NOT invent elements not implied by the user. "
            "If the user specifies that something is 'on fire', ensure that the flames are on that object or creature itself — not merely in the background or environment. "
            "Preserve the user's core intent exactly. "
            "Output ONLY the final prompt in English. No explanations, no markdown."
        )
        # User message: natural phrasing plus an explicit literalness cue.
        core_user_message = f"User concept:\n{user_prompt_en}"
        if any(term in user_prompt_en.lower() for term in ["on fire", "burning", "ablaze", "aflame"]):
            core_user_message += "\n\n⚠️ Note: The subject (e.g., horse, dragon) must be physically on fire with visible flames."
        user_message = core_user_message + "\n\nReference prompts (observe their descriptive grammar):\n" + "\n".join(examples)
        try:
            client = OpenAI(
                base_url="https://router.huggingface.co/v1",
                api_key=os.getenv("PS")
            )
            completion = client.chat.completions.create(
                model="meta-llama/Llama-3.2-3B-Instruct:together",
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_tokens=250,
                temperature=0.2,
                timeout=30.0
            )
            refined = self._clean_output(completion.choices[0].message.content)
            info = f"🧠 Refinado con Llama-3.2-3B vía Together (HF Router, {len(examples)} ejemplos, categoría: {category})."
            return refined, info, examples
        except Exception as e1:
            # FIX: previously `except (APIError, Timeout, Exception)` — the
            # broad Exception already subsumes APIError, and openai 1.x
            # `Timeout` appears to be an httpx config type rather than an
            # exception class, which would make the handler itself raise.
            logger.error(f"Error con Together (HF Router): {e1}")
            try:
                completion = self.hf_client.chat.completions.create(
                    model="meta-llama/Llama-3.2-3B-Instruct",
                    messages=[
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": user_message}
                    ],
                    max_tokens=250,
                    temperature=0.2
                )
                refined = self._clean_output(completion.choices[0].message.content)
                info = f"🧠 Fallback: HF (Hyperbolic, {len(examples)} ejemplos)."
                return refined, info, examples
            except Exception as e2:
                logger.error(f"También falló Hyperbolic: {e2}")
                # Last resort: deterministic semantic enrichment, no LLM.
                enhanced_prompt, _ = self.agent.enhance_prompt(user_prompt_en, category=category)
                return enhanced_prompt.strip(), f"⚠️ LLMs no disponibles. Usando enriquecimiento semántico (categoría: {category}).", examples
class SDXLGenerator:
    """Thin wrapper around the HF Inference API for SDXL text-to-image."""

    def __init__(self):
        # The HF token is stored in the Space secret named "PS".
        hf_token = os.getenv("PS")
        if not hf_token:
            raise ValueError("Secret 'PS' (HF_TOKEN) no encontrado.")
        self.client = InferenceClient(api_key=hf_token)

    def generate_image(self, prompt: str, width: int = 1024, height: int = 1024) -> Tuple[Optional[str], str]:
        """Generate an image for *prompt* and save it under /tmp.

        Returns (path, message) on success or (None, error_message) on
        failure; callers rely on the tuple shape, not on exceptions.
        """
        try:
            image = self.client.text_to_image(
                prompt=prompt,
                model="stabilityai/stable-diffusion-xl-base-1.0",
                width=width,
                height=height
            )
            # FIX: time.time_ns() instead of int(time.time()) so two images
            # generated within the same second do not overwrite each other.
            output_path = f"/tmp/image_{time.time_ns()}.png"
            image.save(output_path)
            return output_path, "Imagen generada con éxito."
        except Exception as e:
            return None, f"Error en generación: {str(e)}"
def create_interface():
    """Build and return the Gradio Blocks UI.

    Initialization failures are tolerated: if the refiner/generator cannot
    be created (e.g. missing 'PS' token), the UI still loads and the button
    handlers report the services as unavailable instead of crashing.
    """
    try:
        refiner = LlamaRefiner()
        generator = SDXLGenerator()
    except Exception as e:
        # Degraded mode: keep the UI alive; handlers short-circuit below.
        refiner = None
        generator = None
        logger.error(f"Inicialización fallida: {e}")

    def refine_prompt_only(prompt: str, category_es: str, progress=gr.Progress()):
        # Handler for the "Refinar prompt" button.
        # Returns (refined_prompt, examples_text, status) for the three outputs.
        if not prompt.strip():
            return "", "", "Prompt vacío."
        if refiner is None:
            return "", "", "Servicios no disponibles."
        progress(0.2, desc="🌍 Traduciendo y mejorando...")
        # Map the Spanish UI labels to the internal category keys.
        category_map = {
            "Automática": "auto",
            "Entidad": "entity",
            "Composición": "composition",
            "Estilo artístico": "style",
            "Imaginativo": "imaginative",
            "Texto": "text"
        }
        category_en = category_map.get(category_es, "auto")
        refined, info, examples = refiner.refine_with_llm(prompt, category_en)
        examples_text = "\n".join(f"{i+1}. {ex}" for i, ex in enumerate(examples)) if examples else "Ninguno"
        status = f"Prompt refinado: {refined}\n{info}"
        return refined, examples_text, status

    def generate_image_only(refined_prompt: str, aspect_ratio: str, progress=gr.Progress()):
        # Handler for the "Generar imagen" button.
        # Returns (image_path_or_None, status_message).
        if not refined_prompt.strip():
            return None, "❌ No hay prompt refinado. Primero haz clic en 'Refinar prompt'."
        if generator is None:
            return None, "❌ Generador no inicializado."
        # SDXL-friendly resolutions for each supported aspect ratio.
        aspect_ratios = {
            "1:1": (1024, 1024),
            "16:9": (1344, 768),
            "9:16": (768, 1344),
            "4:3": (1152, 896),
            "3:4": (896, 1152),
            "21:9": (1536, 640),
            "9:21": (640, 1536),
        }
        width, height = aspect_ratios.get(aspect_ratio, (1024, 1024))
        progress(0.5, desc="🎨 Generando imagen (puede tardar 10-20s)...")
        try:
            image_path, gen_msg = generator.generate_image(refined_prompt, width, height)
            return image_path, gen_msg
        except Exception as e:
            error_msg = f"❌ Error al generar: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

    CATEGORY_CHOICES_ES = ["Automática", "Entidad", "Composición", "Estilo artístico", "Imaginativo", "Texto"]

    with gr.Blocks(title="PromptCraft: Refinamiento Estructural de Prompts") as demo:
        # Header banner (static HTML, Spanish copy shown to the user).
        gr.HTML("""
        <div style="text-align: center; padding: 20px; background: #000000; border-bottom: 1px solid #333;">
        <h1 style="color: #ffffff; font-size: 2.2em; font-weight: 400; margin: 0;">
        PromptCraft: Refinamiento Estructural de Prompts
        </h1>
        <p style="color: #aaaaaa; font-size: 0.95em; margin: 12px 0 0 0; max-width: 800px; margin-left: auto; margin-right: auto; padding: 0 10px; line-height: 1.5;">
        Esta herramienta genera prompts optimizados para modelos de difusión (como Midjourney, Flux o SDXL) mediante el análisis estructural de un dataset de 100.000 prompts.<br>
        El usuario introduce su idea en castellano.<br>
        El sistema traduce ese texto a inglés.<br>
        Recupera los prompts semánticamente más cercanos en el dataset.<br>
        Usa un modelo de lenguaje (Llama-3.2-3B) para reconstruir el prompt del usuario a partir de la estructura descriptiva de los prompts del dataset (sujeto → contexto → entorno → iluminación → estilo).<br>
        Entrega un prompt en inglés listo para generación de imagen.<br>
        No añade elementos no sugeridos por el usuario. Su objetivo es la coherencia estructural, no la invención creativa.
        </p>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                # Left column: user inputs and action buttons.
                prompt_input = gr.Textbox(label="Tu idea (en castellano)", lines=3, placeholder="Ej: un caballo en llamas galopando en un bosque...")
                category_es = gr.Dropdown(label="Categoría", choices=CATEGORY_CHOICES_ES, value="Automática")
                aspect = gr.Dropdown(label="Proporción", choices=["1:1", "16:9", "9:16", "4:3", "3:4", "21:9", "9:21"], value="1:1")
                refine_btn = gr.Button("🔄 Refinar prompt", variant="secondary")
                generate_btn = gr.Button("🎨 Generar imagen", variant="primary")
            with gr.Column():
                # Right column: results (refined prompt, image, diagnostics).
                refined_output = gr.Textbox(label="Prompt refinado (inglés)", interactive=False, lines=3)
                image_out = gr.Image(label="Imagen", type="filepath", height=450)
                examples_out = gr.Textbox(label="Ejemplos del dataset (para análisis)", interactive=False, lines=6)
                status_out = gr.Textbox(label="Estado", interactive=False, lines=4)
        # Footer credits.
        gr.HTML("""
        <div style="text-align: center; padding: 12px; margin-top: 15px; border-top: 1px solid #333; color: #999; font-size: 0.95em;">
        Creado por Angel E. Pariente 🇬🇹 • Sobre una idea de Nacho Ravinovich 🇦🇷
        </div>
        """)
        # Wire the buttons to their handlers. Refinement and generation are
        # split so the user can review the prompt before spending a render.
        refine_btn.click(
            fn=refine_prompt_only,
            inputs=[prompt_input, category_es],
            outputs=[refined_output, examples_out, status_out],
            show_progress=True
        )
        generate_btn.click(
            fn=generate_image_only,
            inputs=[refined_output, aspect],
            outputs=[image_out, status_out],
            show_progress=True
        )
    return demo
if __name__ == "__main__":
    # Entry point: build the UI and serve it on all interfaces, port 7860
    # (the standard port for Hugging Face Spaces).
    promptcraft_app = create_interface()
    promptcraft_app.launch(server_name="0.0.0.0", server_port=7860, show_api=False)