"""Gradio app that generates platform-optimized social-media posts.

Supports Instagram, Facebook, and X, backed by a LoRA-fine-tuned
Gemma-3 270M model (base: unsloth/gemma-3-270m-it, adapter:
newadays/gemma_3_lora_ig_post).
"""

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "unsloth/gemma-3-270m-it"
LORA_ADAPTER = "newadays/gemma_3_lora_ig_post"

# Tokenizer comes from the adapter repo (it may carry chat-template tweaks
# made during fine-tuning); the LoRA weights are attached to the base model.
tokenizer = AutoTokenizer.from_pretrained(LORA_ADAPTER)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float32)
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
model.eval()

# One system-prompt template per platform. Placeholders {brand}, {location},
# {tone}, {description} are filled in by generate_post().
PLATFORM_PROMPTS = {
    "Instagram": (
        "Generate a creative Instagram caption described in the context provided "
        "in triple backticks for a brand called {brand} located in {location} with emojis.\n\n"
        "Please follow the steps below to perform the task:\n"
        "1 - Add important information that will captivate readers.\n"
        "2 - Itemize the good features in a List\n"
        "3 - Include the location\n"
        "4 - Remember to add hashtags related to the product at the end of the caption in a separate line.\n"
        "5 - Not more than 150 words\n\n"
        "Please remember to follow these important guidelines:\n"
        "- Remember to use proper grammar throughout the caption\n"
        "- Keep the caption short and to the point.\n"
        "- Use a {tone} tone throughout the caption.\n"
        "- Create a Sense of Urgency throughout the caption.\n"
        "- Remember to itemize the features in a List\n\n"
        "```\n{description}\n```"
    ),
    "Facebook": (
        "Generate a Facebook post described in the context provided "
        "in triple backticks for a brand called {brand} located in {location}.\n\n"
        "Please follow the steps below to perform the task:\n"
        "1 - Start with a question or hook that grabs attention.\n"
        "2 - Tell a short story or share a relatable scenario about the product or service.\n"
        "3 - Highlight key benefits naturally within the narrative.\n"
        "4 - Include a clear call to action (visit, call, book, comment).\n"
        "5 - Between 100 and 250 words.\n\n"
        "Please remember to follow these important guidelines:\n"
        "- Use a {tone}, conversational tone throughout.\n"
        "- Write in a way that encourages comments and shares.\n"
        "- End with a question to boost engagement.\n"
        "- Use emojis sparingly (2-4 max).\n\n"
        "```\n{description}\n```"
    ),
    "X": (
        "Generate a tweet (X post) described in the context provided "
        "in triple backticks for a brand called {brand} located in {location}.\n\n"
        "Please follow these rules strictly:\n"
        "1 - MUST be under 280 characters total.\n"
        "2 - Be punchy and direct. No bullet lists.\n"
        "3 - Use 1 to 3 relevant hashtags.\n"
        "4 - Include a call to action if possible.\n\n"
        "Please remember to follow these important guidelines:\n"
        "- Use a {tone} tone.\n"
        "- Keep it concise — every word must earn its place.\n"
        "- No fluff, no filler.\n\n"
        "```\n{description}\n```"
    ),
}

# Each row: [platform, brand, location, description, tone] — matches the
# inputs wiring of gr.Examples below.
EXAMPLES = [
    ["Instagram", "Chernov Team Realtor", "Los Angeles, California", "a white house with the words just listed above it", "Enthusiastic"],
    ["Instagram", "Bloom & Petal Florist", "Austin, Texas", "a vibrant bouquet of sunflowers and roses for a summer wedding", "Warm"],
    ["Instagram", "IronForge Gym", "Miami, Florida", "a modern gym interior with free weights and a motivational wall mural", "Motivational"],
    ["Facebook", "Sweet Crumbs Bakery", "Portland, Oregon", "a display case full of freshly baked croissants and artisan sourdough loaves", "Friendly"],
    ["Facebook", "Chernov Team Realtor", "Los Angeles, California", "a modern luxury home with a swimming pool and palm trees in the backyard", "Professional"],
    ["X", "LaunchPad AI", "San Francisco, California", "a new AI-powered productivity tool that helps teams automate repetitive tasks", "Bold"],
    ["X", "Drip Coffee Co.", "Seattle, Washington", "a new single-origin Ethiopian pour-over now available at all locations", "Casual"],
]


def generate_post(platform, brand, location, description, tone):
    """Generate a social-media post for the given platform.

    Args:
        platform: One of "Instagram", "Facebook", "X" (keys of PLATFORM_PROMPTS).
        brand: Brand name inserted into the prompt.
        location: Brand location inserted into the prompt.
        description: Free-text description of the post content.
        tone: Desired tone; lowercased before insertion into the prompt.

    Returns:
        The generated post text, a validation message when required fields
        are blank, or an error message if generation raises.
    """
    if not all([brand.strip(), location.strip(), description.strip()]):
        return "Please fill in Brand Name, Location, and Content Description."
    system_content = PLATFORM_PROMPTS[platform].format(
        brand=brand.strip(),
        location=location.strip(),
        description=description.strip(),
        tone=tone.lower(),
    )
    # Matches training data format: system prompt + empty user message
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": ""},
    ]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Gemma's chat template prepends "<bos>", and tokenizer(...) below adds
    # BOS again, which would yield a double-BOS prompt. Strip the templated
    # one. (The original `startswith("")` check was a no-op.)
    if text.startswith("<bos>"):
        text = text[len("<bos>"):]
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[1]
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=1.0,
                top_p=0.95,
                top_k=64,
                do_sample=True,
            )
        # Decode only the newly generated continuation, not the prompt.
        generated = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        return generated.strip()
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return f"Generation failed: {e}"


with gr.Blocks(title="Social Media Post Generator") as demo:
    gr.Markdown(
        "# Social Media Post Generator\n"
        "Generate platform-optimized posts for **Instagram**, **Facebook**, and **X** "
        "powered by a fine-tuned [Gemma-3 270M](https://huggingface.co/newadays/gemma_3_lora_ig_post) model."
    )
    with gr.Row():
        with gr.Column(scale=1):
            platform = gr.Dropdown(
                choices=["Instagram", "Facebook", "X"],
                value="Instagram",
                label="Platform",
            )
            brand = gr.Textbox(label="Brand Name", placeholder="e.g. Chernov Team Realtor")
            location = gr.Textbox(label="Location", placeholder="e.g. Los Angeles, California")
            description = gr.Textbox(
                label="Content Description",
                placeholder="Describe what the post is about...",
                lines=3,
            )
            tone = gr.Dropdown(
                choices=["Enthusiastic", "Professional", "Friendly", "Warm", "Bold", "Casual", "Motivational"],
                value="Enthusiastic",
                label="Tone",
            )
            generate_btn = gr.Button("Generate Post", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Generated Post",
                lines=12,
                interactive=False,
            )
    generate_btn.click(
        fn=generate_post,
        inputs=[platform, brand, location, description, tone],
        outputs=output,
    )
    gr.Examples(
        examples=EXAMPLES,
        inputs=[platform, brand, location, description, tone],
        label="Click an example to fill the form, then hit Generate Post",
    )

if __name__ == "__main__":
    demo.launch(
        share=True,
        show_error=True,
    )