retowyss committed on
Commit
bce946a
·
verified ·
1 Parent(s): 1154198

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -32
app.py CHANGED
@@ -1,47 +1,57 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
 
5
  MODEL_NAME = "retowyss/PromptBridge-0.6b-Alpha"
6
 
7
- # Load model and tokenizer
 
 
 
 
8
  print("Loading model...")
 
 
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
 
10
  model = AutoModelForCausalLM.from_pretrained(
11
  MODEL_NAME,
12
  trust_remote_code=True,
13
- torch_dtype=torch.float32, # Use float32 for CPU compatibility
14
- device_map="cpu"
15
  )
16
  model.eval()
17
- print("Model loaded!")
18
 
 
19
  def generate_prompt(mode: str, user_prompt: str, temperature: float = 0.7, max_tokens: int = 512):
20
  """Generate prompt transformation."""
21
-
22
  # Map mode to system prompt
23
  system_prompts = {
24
  "Expand": "Expand the prompt.",
25
  "Compress to Sentence": "Compress the prompt into one sentence.",
26
  "Compress to Keywords": "Compress the prompt into keyword format."
27
  }
28
-
29
  system_prompt = system_prompts[mode]
30
-
31
  messages = [
32
  {"role": "system", "content": system_prompt},
33
  {"role": "user", "content": user_prompt}
34
  ]
35
-
36
  # Apply chat template
37
  text = tokenizer.apply_chat_template(
38
  messages,
39
  tokenize=False,
40
  add_generation_prompt=True
41
  )
42
-
43
- inputs = tokenizer(text, return_tensors="pt")
44
-
45
  # Generate
46
  with torch.no_grad():
47
  outputs = model.generate(
@@ -53,13 +63,13 @@ def generate_prompt(mode: str, user_prompt: str, temperature: float = 0.7, max_t
53
  pad_token_id=tokenizer.pad_token_id,
54
  eos_token_id=tokenizer.eos_token_id,
55
  )
56
-
57
  # Decode only the new tokens
58
  response = tokenizer.decode(
59
  outputs[0][inputs['input_ids'].shape[1]:],
60
  skip_special_tokens=True
61
  )
62
-
63
  return response.strip()
64
 
65
  # Example prompts
@@ -78,19 +88,19 @@ compression_examples = [
78
  with gr.Blocks(title="PromptBridge-0.6b-Alpha") as demo:
79
  gr.Markdown("""
80
  # 🌉 PromptBridge-0.6b-Alpha
81
-
82
  A specialized model for bidirectional prompt transformation for text-to-image generation.
83
-
84
  **Trained exclusively on prompts featuring single, adult humanoid subjects.**
85
-
86
  ### Modes:
87
  - **Expand**: Convert brief keywords into detailed image generation prompts
88
  - **Compress to Sentence**: Condense detailed prompts into a single flowing sentence
89
  - **Compress to Keywords**: Convert prompts into comma-separated keywords
90
-
91
  ⚠️ **Note**: May generate sensitive content (PG to R-rated). You are responsible for the output.
92
  """)
93
-
94
  with gr.Row():
95
  with gr.Column():
96
  mode = gr.Radio(
@@ -98,13 +108,13 @@ with gr.Blocks(title="PromptBridge-0.6b-Alpha") as demo:
98
  value="Expand",
99
  label="Mode"
100
  )
101
-
102
  user_input = gr.Textbox(
103
  lines=5,
104
  placeholder="Enter your prompt here...",
105
  label="Input Prompt"
106
  )
107
-
108
  with gr.Row():
109
  temperature = gr.Slider(
110
  minimum=0.1,
@@ -113,7 +123,7 @@ with gr.Blocks(title="PromptBridge-0.6b-Alpha") as demo:
113
  step=0.1,
114
  label="Temperature"
115
  )
116
-
117
  max_tokens = gr.Slider(
118
  minimum=64,
119
  maximum=1024,
@@ -121,18 +131,18 @@ with gr.Blocks(title="PromptBridge-0.6b-Alpha") as demo:
121
  step=64,
122
  label="Max Tokens"
123
  )
124
-
125
  submit_btn = gr.Button("Generate", variant="primary")
126
-
127
  with gr.Column():
128
  output = gr.Textbox(
129
  lines=15,
130
  label="Output"
131
  )
132
-
133
  # Examples
134
  gr.Markdown("### Examples")
135
-
136
  with gr.Tab("Expansion Examples"):
137
  gr.Examples(
138
  examples=expansion_examples,
@@ -141,7 +151,7 @@ with gr.Blocks(title="PromptBridge-0.6b-Alpha") as demo:
141
  fn=generate_prompt,
142
  cache_examples=False
143
  )
144
-
145
  with gr.Tab("Compression Examples"):
146
  gr.Examples(
147
  examples=compression_examples,
@@ -150,22 +160,22 @@ with gr.Blocks(title="PromptBridge-0.6b-Alpha") as demo:
150
  fn=generate_prompt,
151
  cache_examples=False
152
  )
153
-
154
  gr.Markdown("""
155
  ### About
156
-
157
  **Model**: [retowyss/PromptBridge-0.6b-Alpha](https://huggingface.co/retowyss/PromptBridge-0.6b-Alpha)
158
-
159
  **Training Data**: ~300K synthetic prompt pairs (PG to R-rated, X-rated content removed)
160
-
161
  **Limitations**:
162
  - Optimized for human subjects only (performance on other subjects untested)
163
  - May generate prompts exceeding typical image model token limits
164
  - Cannot perform general instruction-following or reasoning tasks
165
-
166
  **License**: Apache 2.0
167
  """)
168
-
169
  # Connect the button
170
  submit_btn.click(
171
  fn=generate_prompt,
 
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Hub repo of the prompt-rewriting model this Space serves.
MODEL_NAME = "retowyss/PromptBridge-0.6b-Alpha"

# Choose compute device and a matching dtype: bfloat16 on GPU, float32 on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32

print(f"Using device: {DEVICE}")
print("Loading model...")

# Tokenizer and model are loaded once at import time and shared as module globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    # Let HF shard/place layers automatically on GPU; pin to CPU otherwise.
    device_map="auto" if DEVICE == "cuda" else "cpu",
)
model.eval()  # inference only — disable dropout / training-mode behavior
print(f"Model loaded on {DEVICE}!")
 
28
+ @spaces.GPU(duration=60) # Allocate GPU for 60 seconds
29
  def generate_prompt(mode: str, user_prompt: str, temperature: float = 0.7, max_tokens: int = 512):
30
  """Generate prompt transformation."""
31
+
32
  # Map mode to system prompt
33
  system_prompts = {
34
  "Expand": "Expand the prompt.",
35
  "Compress to Sentence": "Compress the prompt into one sentence.",
36
  "Compress to Keywords": "Compress the prompt into keyword format."
37
  }
38
+
39
  system_prompt = system_prompts[mode]
40
+
41
  messages = [
42
  {"role": "system", "content": system_prompt},
43
  {"role": "user", "content": user_prompt}
44
  ]
45
+
46
  # Apply chat template
47
  text = tokenizer.apply_chat_template(
48
  messages,
49
  tokenize=False,
50
  add_generation_prompt=True
51
  )
52
+
53
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)
54
+
55
  # Generate
56
  with torch.no_grad():
57
  outputs = model.generate(
 
63
  pad_token_id=tokenizer.pad_token_id,
64
  eos_token_id=tokenizer.eos_token_id,
65
  )
66
+
67
  # Decode only the new tokens
68
  response = tokenizer.decode(
69
  outputs[0][inputs['input_ids'].shape[1]:],
70
  skip_special_tokens=True
71
  )
72
+
73
  return response.strip()
74
 
75
  # Example prompts
 
88
  with gr.Blocks(title="PromptBridge-0.6b-Alpha") as demo:
89
  gr.Markdown("""
90
  # 🌉 PromptBridge-0.6b-Alpha
91
+
92
  A specialized model for bidirectional prompt transformation for text-to-image generation.
93
+
94
  **Trained exclusively on prompts featuring single, adult humanoid subjects.**
95
+
96
  ### Modes:
97
  - **Expand**: Convert brief keywords into detailed image generation prompts
98
  - **Compress to Sentence**: Condense detailed prompts into a single flowing sentence
99
  - **Compress to Keywords**: Convert prompts into comma-separated keywords
100
+
101
  ⚠️ **Note**: May generate sensitive content (PG to R-rated). You are responsible for the output.
102
  """)
103
+
104
  with gr.Row():
105
  with gr.Column():
106
  mode = gr.Radio(
 
108
  value="Expand",
109
  label="Mode"
110
  )
111
+
112
  user_input = gr.Textbox(
113
  lines=5,
114
  placeholder="Enter your prompt here...",
115
  label="Input Prompt"
116
  )
117
+
118
  with gr.Row():
119
  temperature = gr.Slider(
120
  minimum=0.1,
 
123
  step=0.1,
124
  label="Temperature"
125
  )
126
+
127
  max_tokens = gr.Slider(
128
  minimum=64,
129
  maximum=1024,
 
131
  step=64,
132
  label="Max Tokens"
133
  )
134
+
135
  submit_btn = gr.Button("Generate", variant="primary")
136
+
137
  with gr.Column():
138
  output = gr.Textbox(
139
  lines=15,
140
  label="Output"
141
  )
142
+
143
  # Examples
144
  gr.Markdown("### Examples")
145
+
146
  with gr.Tab("Expansion Examples"):
147
  gr.Examples(
148
  examples=expansion_examples,
 
151
  fn=generate_prompt,
152
  cache_examples=False
153
  )
154
+
155
  with gr.Tab("Compression Examples"):
156
  gr.Examples(
157
  examples=compression_examples,
 
160
  fn=generate_prompt,
161
  cache_examples=False
162
  )
163
+
164
  gr.Markdown("""
165
  ### About
166
+
167
  **Model**: [retowyss/PromptBridge-0.6b-Alpha](https://huggingface.co/retowyss/PromptBridge-0.6b-Alpha)
168
+
169
  **Training Data**: ~300K synthetic prompt pairs (PG to R-rated, X-rated content removed)
170
+
171
  **Limitations**:
172
  - Optimized for human subjects only (performance on other subjects untested)
173
  - May generate prompts exceeding typical image model token limits
174
  - Cannot perform general instruction-following or reasoning tasks
175
+
176
  **License**: Apache 2.0
177
  """)
178
+
179
  # Connect the button
180
  submit_btn.click(
181
  fn=generate_prompt,