File size: 1,502 Bytes
739eac7
e0f8fd3
3f73899
6395efd
e0f8fd3
7fbfabe
 
75f4ca1
7fbfabe
f5025f1
1a52249
7fbfabe
 
d535050
7fbfabe
1a52249
7fbfabe
 
469b10d
 
 
 
7fbfabe
f5025f1
469b10d
1a52249
 
 
3f73899
cf246af
f5025f1
1a52249
7fbfabe
 
1a52249
 
 
 
7fbfabe
1a52249
 
 
 
 
d535050
6395efd
469b10d
cf246af
 
f5025f1
469b10d
f5025f1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# --- Configuration -----------------------------------------------------------
# Instruction-tuned Gemma 1.1 checkpoint, the smaller 2B variant.
MODEL_ID = "google/gemma-1.1-2b-it"
# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Cap on newly generated tokens per request (deliberately conservative).
MAX_TOKENS = 80

def load_model():
    """Fetch the tokenizer and model for ``MODEL_ID`` and place the model on CPU.

    Both downloads authenticate with ``HF_TOKEN`` (which is ``None`` when the
    environment variable is not set). The model weights are loaded as
    float32.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    print("🔄 Loading model...")

    auth = {"token": HF_TOKEN}
    tok = AutoTokenizer.from_pretrained(MODEL_ID, **auth)

    # float32 weights, then an explicit move onto the CPU device.
    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.float32, **auth
    )
    lm = lm.to("cpu")

    print("✅ Model loaded successfully!")
    return tok, lm

# Load once at import time; predict() reads these module-level objects.
(tokenizer, model) = load_model()

def predict(topic):
    """Generate a short Hook / Point / CTA script about ``topic``.

    Args:
        topic: Free-text subject from the Gradio textbox (may be empty/None).

    Returns:
        Only the newly generated text; the prompt is not echoed back.
        Returns a string starting with ``"Error: "`` if the topic is blank
        or generation fails. Errors are returned rather than raised so the
        UI shows them in the output box.
    """
    # Don't spend a full CPU generation on an empty request.
    topic = (topic or "").strip()
    if not topic:
        return "Error: please enter a topic."

    prompt = f"Create a short script about {topic}:\n1) Hook\n2) Point\n3) CTA\n\nScript:"
    # NOTE(review): gemma-*-it checkpoints are instruction-tuned with a chat
    # template; tokenizer.apply_chat_template may give better adherence than
    # this raw prompt — consider evaluating.
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            # Greedy decoding ignores `temperature`; sampling must be enabled
            # explicitly for it to have any effect.
            do_sample=True,
            temperature=0.7,
        )

        # For a causal LM, generate() returns prompt + continuation; drop the
        # prompt tokens so only the generated script is shown.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

    except Exception as e:  # UI boundary: report instead of crashing the app
        import traceback

        traceback.print_exc()  # keep the full trace in the server logs
        return f"Error: {e}"

# Minimal UI: one topic textbox in, one script textbox out, exposed to API
# clients under the "predict" endpoint name.
topic_box = gr.Textbox(label="Topic")
script_box = gr.Textbox(label="Script", lines=4)
demo = gr.Interface(
    fn=predict,
    inputs=topic_box,
    outputs=script_box,
    api_name="predict",
)
# Listen on all network interfaces (0.0.0.0), not just localhost.
demo.launch(server_name="0.0.0.0")