import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

MODEL_ID = "google/gemma-1.1-2b-it"
# Gemma is a gated model on the Hub, so HF_TOKEN must be set (e.g. as a Space secret).
HF_TOKEN = os.getenv("HF_TOKEN")
MAX_TOKENS = 80


def load_model():
    """Simplified model loading that works in Spaces."""
    print("🔄 Loading model...")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

    # float32 on CPU: half precision is poorly supported by CPU kernels,
    # and the free Spaces tier has no GPU.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        token=HF_TOKEN,
    ).to("cpu")

    print("✅ Model loaded successfully!")
    return tokenizer, model


tokenizer, model = load_model()
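
# Possible extension (an assumption, not part of the original): pick device and
# dtype at runtime so the same file also works on a GPU Space. Sketch only:
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   dtype = torch.float16 if device == "cuda" else torch.float32
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID, torch_dtype=dtype, token=HF_TOKEN
#   ).to(device)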


def predict(topic):
    """Memory-safe generation."""
    try:
        prompt = f"Create a short script about {topic}:\n1) Hook\n2) Point\n3) CTA\n\nScript:"
        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")

        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            do_sample=True,  # temperature is ignored under greedy decoding
            temperature=0.7,
        )

        # Decode only the newly generated tokens so the prompt isn't echoed back.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)

    except Exception as e:
        return f"Error: {e}"


gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Topic"),
    outputs=gr.Textbox(label="Script", lines=4),
    api_name="predict",
).launch(server_name="0.0.0.0")
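
# Example client call once the app is running (the local URL is an assumption;
# a deployed Space would use its own URL):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict("morning routines", api_name="/predict"))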