# Gbenga — added peft (commit b6123a2)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Base model on the Hub and the LoRA adapter produced by fine-tuning it.
BASE_MODEL = "unsloth/gemma-3-270m-it"
LORA_ADAPTER = "newadays/gemma_3_lora_ig_post"
# Tokenizer comes from the adapter repo — presumably so any tokens saved
# during fine-tuning are picked up (TODO confirm it differs from the base).
tokenizer = AutoTokenizer.from_pretrained(LORA_ADAPTER)
# float32 keeps inference safe on CPU-only hardware (no half-precision kernels needed).
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float32)
# Attach the LoRA weights on top of the (frozen) base model.
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
model.eval()  # inference mode: disables dropout etc.
# Per-platform prompt templates. Each template exposes the placeholders
# {brand}, {location}, {description} and {tone}, filled in by generate_post();
# the user-supplied description is wrapped in triple backticks to delimit it
# from the instructions.
PLATFORM_PROMPTS = {
"Instagram": (
"Generate a creative Instagram caption described in the context provided "
"in triple backticks for a brand called {brand} located in {location} with emojis.\n\n"
"Please follow the steps below to perform the task:\n"
"1 - Add important information that will captivate readers.\n"
"2 - Itemize the good features in a List\n"
"3 - Include the location\n"
"4 - Remember to add hashtags related to the product at the end of the caption in a separate line.\n"
"5 - Not more than 150 words\n\n"
"Please remember to follow these important guidelines:\n"
"- Remember to use proper grammar throughout the caption\n"
"- Keep the caption short and to the point.\n"
"- Use a {tone} tone throughout the caption.\n"
"- Create a Sense of Urgency throughout the caption.\n"
"- Remember to itemize the features in a List\n\n"
"```\n{description}\n```"
),
"Facebook": (
"Generate a Facebook post described in the context provided "
"in triple backticks for a brand called {brand} located in {location}.\n\n"
"Please follow the steps below to perform the task:\n"
"1 - Start with a question or hook that grabs attention.\n"
"2 - Tell a short story or share a relatable scenario about the product or service.\n"
"3 - Highlight key benefits naturally within the narrative.\n"
"4 - Include a clear call to action (visit, call, book, comment).\n"
"5 - Between 100 and 250 words.\n\n"
"Please remember to follow these important guidelines:\n"
"- Use a {tone}, conversational tone throughout.\n"
"- Write in a way that encourages comments and shares.\n"
"- End with a question to boost engagement.\n"
"- Use emojis sparingly (2-4 max).\n\n"
"```\n{description}\n```"
),
"X": (
"Generate a tweet (X post) described in the context provided "
"in triple backticks for a brand called {brand} located in {location}.\n\n"
"Please follow these rules strictly:\n"
"1 - MUST be under 280 characters total.\n"
"2 - Be punchy and direct. No bullet lists.\n"
"3 - Use 1 to 3 relevant hashtags.\n"
"4 - Include a call to action if possible.\n\n"
"Please remember to follow these important guidelines:\n"
"- Use a {tone} tone.\n"
"- Keep it concise — every word must earn its place.\n"
"- No fluff, no filler.\n\n"
"```\n{description}\n```"
),
}
# Rows for the gr.Examples widget; column order matches generate_post's
# signature: [platform, brand, location, description, tone].
EXAMPLES = [
["Instagram", "Chernov Team Realtor", "Los Angeles, California",
"a white house with the words just listed above it", "Enthusiastic"],
["Instagram", "Bloom & Petal Florist", "Austin, Texas",
"a vibrant bouquet of sunflowers and roses for a summer wedding", "Warm"],
["Instagram", "IronForge Gym", "Miami, Florida",
"a modern gym interior with free weights and a motivational wall mural", "Motivational"],
["Facebook", "Sweet Crumbs Bakery", "Portland, Oregon",
"a display case full of freshly baked croissants and artisan sourdough loaves", "Friendly"],
["Facebook", "Chernov Team Realtor", "Los Angeles, California",
"a modern luxury home with a swimming pool and palm trees in the backyard", "Professional"],
["X", "LaunchPad AI", "San Francisco, California",
"a new AI-powered productivity tool that helps teams automate repetitive tasks", "Bold"],
["X", "Drip Coffee Co.", "Seattle, Washington",
"a new single-origin Ethiopian pour-over now available at all locations", "Casual"],
]
def generate_post(platform, brand, location, description, tone):
    """Generate a social-media post for *platform* using the fine-tuned model.

    Args:
        platform: A key of PLATFORM_PROMPTS ("Instagram", "Facebook", or "X").
        brand: Brand name inserted into the prompt.
        location: Brand location inserted into the prompt.
        description: Free-text description of the post content.
        tone: Tone label; lower-cased before insertion into the prompt.

    Returns:
        The generated post text, or a human-readable error message when a
        required field is missing or generation fails.
    """
    # Gradio textboxes can hand back None (e.g. when cleared programmatically);
    # normalize before calling .strip() so validation never raises.
    brand = (brand or "").strip()
    location = (location or "").strip()
    description = (description or "").strip()
    if not (brand and location and description):
        return "Please fill in Brand Name, Location, and Content Description."
    system_content = PLATFORM_PROMPTS[platform].format(
        brand=brand,
        location=location,
        description=description,
        tone=tone.lower(),
    )
    # Matches the training data format: system prompt + empty user message.
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": ""},
    ]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # apply_chat_template already emits <bos>; drop it so the tokenizer call
    # below does not add a duplicate BOS token.
    text = text.removeprefix("<bos>")
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[1]
    try:
        # inference_mode: stricter, slightly faster no_grad for pure inference.
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=1.0,
                top_p=0.95,
                top_k=64,
                do_sample=True,
            )
        # Decode only the newly generated tokens, skipping the echoed prompt.
        generated = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        return generated.strip()
    except Exception as e:
        # Boundary handler: surface failures in the UI rather than crashing
        # the Gradio worker.
        return f"Generation failed: {e}"
# ---- UI definition ------------------------------------------------------
# Two-column layout: inputs on the left, generated post on the right.
with gr.Blocks(title="Social Media Post Generator") as demo:
    gr.Markdown(
        "# Social Media Post Generator\n"
        "Generate platform-optimized posts for **Instagram**, **Facebook**, and **X** "
        "powered by a fine-tuned [Gemma-3 270M](https://huggingface.co/newadays/gemma_3_lora_ig_post) model."
    )
    with gr.Row():
        # Left column: all user-editable fields plus the submit button.
        with gr.Column(scale=1):
            platform_dd = gr.Dropdown(
                choices=["Instagram", "Facebook", "X"],
                value="Instagram",
                label="Platform",
            )
            brand_tb = gr.Textbox(label="Brand Name", placeholder="e.g. Chernov Team Realtor")
            location_tb = gr.Textbox(label="Location", placeholder="e.g. Los Angeles, California")
            description_tb = gr.Textbox(
                label="Content Description",
                placeholder="Describe what the post is about...",
                lines=3,
            )
            tone_dd = gr.Dropdown(
                choices=["Enthusiastic", "Professional", "Friendly", "Warm", "Bold", "Casual", "Motivational"],
                value="Enthusiastic",
                label="Tone",
            )
            submit_btn = gr.Button("Generate Post", variant="primary")
        # Right column: read-only output box.
        with gr.Column(scale=1):
            result_tb = gr.Textbox(
                label="Generated Post",
                lines=12,
                interactive=False,
            )
    # One shared field list keeps the click handler and the examples in sync.
    form_fields = [platform_dd, brand_tb, location_tb, description_tb, tone_dd]
    submit_btn.click(fn=generate_post, inputs=form_fields, outputs=result_tb)
    gr.Examples(
        examples=EXAMPLES,
        inputs=form_fields,
        label="Click an example to fill the form, then hit Generate Post",
    )
if __name__ == "__main__":
    # share=True requests a temporary public gradio.live link when running
    # locally; show_error surfaces Python exceptions in the web UI.
    # (Removed dead commented-out options: `enable_queue` no longer exists as
    # a launch() argument in current Gradio, and `cache_examples` belongs to
    # gr.Examples, not launch().)
    demo.launch(share=True, show_error=True)