boo4blue committed
Commit e051aaf · verified · 1 Parent(s): a3465da

Update app.py

Files changed (1)
  1. app.py +239 -257
app.py CHANGED
@@ -1,304 +1,286 @@
- # app.py
- # Universal AI for Hugging Face Spaces — text + optional image, memory, system prompt, and generation controls.
- # Works with both the Gradio UI and the Hugging Face Inference API.
- #
- # Inference API payloads:
- # - Simple (string):
- #   { "inputs": "Explain transformers in simple terms." }
- #
- # - Universal (JSON):
- #   {
- #     "mode": "chat",
- #     "inputs": "Describe this image and write a tweet about it.",
- #     "image": "<base64-encoded-image-optional>",
- #     "options": { "temperature": 0.7, "max_new_tokens": 256 },
- #     "system": "You are a concise, tactical assistant.",
- #     "reset": false
- #   }
-
  import os
- import io
- import json
- import base64
- from collections import deque
-
- from PIL import Image
  import gradio as gr
 
- from transformers import pipeline
-
- # -----------------------------
- # Model choices (tune as needed)
- # -----------------------------
- TEXT_MODEL = os.getenv("TEXT_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
- # Pick a lightweight image caption model so it runs on free tiers
- IMAGE_CAPTION_MODEL = os.getenv("IMAGE_MODEL", "nlpconnect/vit-gpt2-image-captioning")
 
  # -----------------------------
- # Load pipelines
  # -----------------------------
- text_gen = pipeline(
-     "text-generation",
-     model=TEXT_MODEL,
-     trust_remote_code=True,
-     device_map="auto"
  )
 
- # Lazily load image captioning when first used (saves cold start)
- _image_captioner = None
- def get_image_captioner():
-     global _image_captioner
-     if _image_captioner is None:
-         _image_captioner = pipeline("image-to-text", model=IMAGE_CAPTION_MODEL, device_map="auto")
-     return _image_captioner
 
  # -----------------------------
- # Memory and prompting
  # -----------------------------
- # Keep a short rolling memory of turns: [(user, assistant), ...]
- MEMORY_MAX_TURNS = int(os.getenv("MEMORY_TURNS", "6"))
- memory = deque(maxlen=MEMORY_MAX_TURNS)
-
- DEFAULT_SYSTEM_PROMPT = (
-     "You are UniversalAI: a concise, capable, and adaptive assistant. "
-     "Always be clear, practical, and accurate. If tools are unavailable, say so briefly then proceed with your best reasoning. "
-     "Use step-by-step explanations only when they add value."
  )
 
- def build_prompt(user_msg: str, system_prompt: str) -> str:
-     # Construct a clean, instruction-tuned style prompt.
-     # Mistral Instruct can respond well to plain text, but [INST] tags often help.
-     lines = []
-     sys = system_prompt.strip() if system_prompt else DEFAULT_SYSTEM_PROMPT
-     lines.append(f"<<SYS>>\n{sys}\n<</SYS>>")
-     for u, a in list(memory):
-         lines.append(f"[INST] {u.strip()} [/INST]\n{a.strip()}")
      lines.append(f"[INST] {user_msg.strip()} [/INST]\n")
      return "\n".join(lines)
 
  # -----------------------------
- # Utilities
  # -----------------------------
- def ensure_pil_image(img_input):
-     # Handles either Gradio image (PIL) or base64 string
-     if img_input is None:
-         return None
-     if isinstance(img_input, Image.Image):
-         return img_input
-     if isinstance(img_input, str):
-         try:
-             # If it's a data URL, strip the prefix
-             if img_input.startswith("data:"):
-                 img_input = img_input.split(",", 1)[1]
-             data = base64.b64decode(img_input)
-             return Image.open(io.BytesIO(data)).convert("RGB")
-         except Exception:
-             return None
-     return None
-
- def caption_image(pil_img):
-     cap = get_image_captioner()
-     try:
-         result = cap(pil_img)
-         if isinstance(result, list) and len(result) and "generated_text" in result[0]:
-             return result[0]["generated_text"]
-         # Some image-to-text pipelines return a string directly
-         if isinstance(result, str):
-             return result
-     except Exception as e:
-         return f"(Image captioning failed: {e})"
-     return "(No caption generated)"
-
- def generate_text(prompt, temperature=0.7, max_new_tokens=256, top_p=0.9, do_sample=True):
-     out = text_gen(
-         prompt,
-         temperature=float(temperature),
          max_new_tokens=int(max_new_tokens),
-         top_p=float(top_p),
-         do_sample=bool(do_sample),
-         pad_token_id=50256  # safe default for many GPT-like models
      )
-     # Pipeline returns a list of dicts with 'generated_text'
-     return out[0]["generated_text"]
 
- def extract_assistant_reply(full_generated_text: str, user_prompt: str) -> str:
-     # Heuristic: get only the text after the final user [INST] block.
-     # If tags not found, return the full generated text.
-     try:
-         marker = f"[INST] {user_prompt.strip()} [/INST]"
-         if marker in full_generated_text:
-             return full_generated_text.split(marker, 1)[-1].strip()
-         return full_generated_text.strip()
-     except Exception:
-         return full_generated_text.strip()
 
  # -----------------------------
- # Core handler (works for both UI and API)
  # -----------------------------
- def handle_request(
-     user_input: str = "",
-     image_input=None,
-     temperature: float = 0.7,
-     max_new_tokens: int = 256,
-     system_prompt: str = DEFAULT_SYSTEM_PROMPT,
-     reset_memory: bool = False
- ):
-     # Reset memory if requested
-     if reset_memory:
-         memory.clear()
-
-     # If image exists, caption it and augment the user input
-     pil_img = ensure_pil_image(image_input)
-     vision_context = ""
-     if pil_img is not None:
-         caption = caption_image(pil_img)
-         vision_context = f"\n[Image context]: {caption}"
-
-     final_user = (user_input or "").strip()
-     if vision_context:
-         final_user = f"{final_user}\n{vision_context}".strip()
-
-     # Build final prompt with system + memory
-     full_prompt = build_prompt(final_user, system_prompt)
-
-     # Generate
-     gen_text = generate_text(
-         full_prompt,
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=0.9,
-         do_sample=True
-     )
-     assistant = extract_assistant_reply(gen_text, final_user)
 
-     # Update memory
-     if final_user:
-         memory.append((final_user, assistant))
 
-     return assistant
 
  # -----------------------------
- # Inference API adapter
  # -----------------------------
- # This lets you send either a simple string or a JSON object in "inputs".
- # If "inputs" is dict-like JSON, we extract 'mode', 'image', 'options', etc.
- def hf_api_predict(inputs):
      try:
-         # Case 1: inputs is already a dict-like object (Gradio may pass parsed JSON)
-         payload = inputs if isinstance(inputs, dict) else None
-
-         # Case 2: inputs is a string that might be JSON
-         if payload is None and isinstance(inputs, str) and inputs.strip().startswith("{"):
-             payload = json.loads(inputs)
-
-         if payload is None:
-             # Treat as plain prompt
-             return handle_request(user_input=str(inputs))
-
-         # Extract universal fields
-         mode = payload.get("mode", "chat")
-         system = payload.get("system", DEFAULT_SYSTEM_PROMPT)
-         reset = bool(payload.get("reset", False))
-         options = payload.get("options", {}) or {}
-
-         # Inputs can be a string or object
-         user_msg = payload.get("inputs", "")
-         image_b64 = payload.get("image", None)
-
-         temperature = float(options.get("temperature", 0.7))
-         max_new_tokens = int(options.get("max_new_tokens", 256))
-
-         # Run
-         reply = handle_request(
-             user_input=user_msg,
-             image_input=image_b64,
-             temperature=temperature,
-             max_new_tokens=max_new_tokens,
-             system_prompt=system,
-             reset_memory=reset
-         )
-         return reply
      except Exception as e:
-         return f"(Error parsing/processing request: {e})"
 
  # -----------------------------
- # Gradio UI
  # -----------------------------
- with gr.Blocks(title="UniversalAI — Text + Image, Memory, Controls") as demo:
-     gr.Markdown("## UniversalAI Text + Image, Memory, Controls")
 
      with gr.Row():
-         with gr.Column():
-             sys_box = gr.Textbox(
-                 label="System prompt",
-                 value=DEFAULT_SYSTEM_PROMPT,
-                 lines=3
-             )
-             prompt_box = gr.Textbox(
-                 label="Your message",
-                 placeholder="Ask anything… (You can also attach an image)",
-                 lines=4
-             )
-             image_box = gr.Image(
-                 label="Optional image",
-                 type="pil"
              )
          with gr.Row():
-             temp_slider = gr.Slider(
-                 minimum=0.1, maximum=1.2, value=0.7, step=0.05,
-                 label="Creativity (temperature)"
-             )
-             max_tokens_slider = gr.Slider(
-                 minimum=32, maximum=1024, value=256, step=16,
-                 label="Max new tokens"
              )
-             reset_chk = gr.Checkbox(
-                 label="Reset memory before this message",
-                 value=False
-             )
-             submit_btn = gr.Button("Send", variant="primary")
-             clear_btn = gr.Button("Clear memory", variant="secondary")
 
-         with gr.Column():
-             output_box = gr.Textbox(
-                 label="Assistant",
-                 lines=20
              )
-
-     def ui_send(system, prompt, image, temp, max_new, reset):
-         reply = handle_request(
-             user_input=prompt or "",
-             image_input=image,
-             temperature=temp,
-             max_new_tokens=int(max_new),
-             system_prompt=system or DEFAULT_SYSTEM_PROMPT,
-             reset_memory=bool(reset)
-         )
-         return reply
-
-     def ui_clear():
-         memory.clear()
-         return "Memory cleared."
-
-     submit_btn.click(
-         fn=ui_send,
-         inputs=[sys_box, prompt_box, image_box, temp_slider, max_tokens_slider, reset_chk],
-         outputs=[output_box]
      )
 
-     clear_btn.click(
-         fn=ui_clear,
-         inputs=[],
-         outputs=[output_box]
      )
 
-     # Expose a simple API endpoint for HF Inference API callers:
-     # Map a Textbox "inputs" to our universal parser.
-     # This keeps the official /models/<user>/<space> endpoint working with JSON too.
-     api_in = gr.Textbox(label="API (inputs)", visible=False)
-     api_out = gr.Textbox(label="API (outputs)", visible=False)
-     demo.load(fn=lambda: "", inputs=None, outputs=None)  # no-op to ensure Blocks initializes
-     demo.add_api_route("/predict", hf_api_predict, inputs=api_in, outputs=api_out)  # Gradio 4.x
 
  if __name__ == "__main__":
-     demo.launch()
 
  import os
+ import time
  import gradio as gr
 
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
  # -----------------------------
+ # Config
  # -----------------------------
+ DEFAULT_MODEL = os.getenv("TEXT_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
+ DEFAULT_SYSTEM = (
+     "You are UniversalAI — a concise, capable, adaptive assistant. "
+     "Answer clearly, use Markdown for structure, show code in fenced blocks. "
+     "Ask clarifying questions when needed. Keep answers tight but complete."
  )
+ DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
+ DEFAULT_MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
 
+ # Safety pad token for many GPT-like models
+ DEFAULT_PAD_TOKEN_ID = 50256
 
  # -----------------------------
+ # Load model
  # -----------------------------
+ torch.set_grad_enabled(False)
+
+ tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL, use_fast=True, trust_remote_code=True)
+ if tokenizer.pad_token_id is None:
+     tokenizer.pad_token_id = DEFAULT_PAD_TOKEN_ID
+
+ model = AutoModelForCausalLM.from_pretrained(
+     DEFAULT_MODEL,
+     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto",
+     trust_remote_code=True
  )
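+ # Note: bfloat16 is selected whenever CUDA is available; pre-Ampere GPUs lack
+ # native bfloat16 support, so float16 may be the safer choice on older cards.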
 
+ # -----------------------------
+ # Prompt building (ChatML/INST hybrid)
+ # -----------------------------
+ def build_prompt(system_prompt: str, history: list[tuple[str, str]], user_msg: str) -> str:
+     # history is a list of (user, assistant) tuples
+     sys = system_prompt.strip() if system_prompt else DEFAULT_SYSTEM
+     lines = [f"<<SYS>>\n{sys}\n<</SYS>>"]
+     for u, a in history:
+         u = (u or "").strip()
+         a = (a or "").strip()
+         if not u and not a:
+             continue
+         lines.append(f"[INST] {u} [/INST]\n{a}")
      lines.append(f"[INST] {user_msg.strip()} [/INST]\n")
      return "\n".join(lines)
 
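+ # For example, build_prompt("Be terse.", [("Hi", "Hello!")], "What is Mistral?") returns:
+ #   <<SYS>>
+ #   Be terse.
+ #   <</SYS>>
+ #   [INST] Hi [/INST]
+ #   Hello!
+ #   [INST] What is Mistral? [/INST]
+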
  # -----------------------------
+ # Generation (streaming)
  # -----------------------------
+ def stream_generate(
+     prompt: str,
+     temperature: float,
+     max_new_tokens: int,
+ ):
+     inputs = tokenizer(prompt, return_tensors="pt")
+     for k in inputs:
+         inputs[k] = inputs[k].to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     gen_kwargs = dict(
+         **inputs,
+         streamer=streamer,
+         do_sample=True,
+         temperature=float(max(0.01, temperature)),
+         top_p=0.9,
          max_new_tokens=int(max_new_tokens),
+         repetition_penalty=1.05,
+         pad_token_id=tokenizer.pad_token_id,
      )
 
+     # Run generation in a background thread so we can yield tokens
+     import threading
+     thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
+     thread.start()
+
+     partial = ""
+     for new_text in streamer:
+         partial += new_text
+         yield partial
 
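+ # Each yielded value is the full accumulated text so far, not a delta:
+ #   for text in stream_generate(build_prompt(DEFAULT_SYSTEM, [], "Hi"), 0.7, 64):
+ #       print(text)
+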
  # -----------------------------
+ # Slash commands
  # -----------------------------
+ def apply_slash_commands(user_msg: str, system_prompt: str, history: list[tuple[str, str]]):
+     msg = (user_msg or "").strip()
+     sys = system_prompt
+
+     if msg.lower().startswith("/reset"):
+         return "", sys, [], "Memory cleared."
 
+     if msg.lower().startswith("/system:"):
+         new_sys = msg.split(":", 1)[1].strip()
+         if new_sys:
+             return "", new_sys, history, "System prompt updated."
+         # Empty command: surface the note instead of sending "/system:" to the model
+         return "", sys, history, "No system text provided."
 
+     return msg, sys, history, None
 
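+ # e.g. apply_slash_commands("/system: Talk like a pirate.", sys, hist)
+ #   -> ("", "Talk like a pirate.", hist, "System prompt updated.")
+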
  # -----------------------------
+ # Chat handlers
  # -----------------------------
+ def chat_submit(
+     user_msg, chat_history, system_prompt, temperature, max_new_tokens, last_user
+ ):
+     # Initialize states
+     chat_history = chat_history or []
+     last_user = ""
+
+     # Slash commands
+     processed_msg, new_system, new_history, note = apply_slash_commands(user_msg, system_prompt, chat_history)
+     if processed_msg == "" and note is not None:
+         # Command-only case: show the system note (append to new_history so /reset really clears the view)
+         new_history.append((user_msg, note))
+         yield "", new_history, new_system, last_user
+         return
+
+     # Build prompt
+     prompt = build_prompt(new_system, new_history, processed_msg)
+
+     # Add placeholder for streaming
+     new_history.append((processed_msg, ""))
+
+     # Stream tokens into the last assistant message
+     for partial in stream_generate(prompt, temperature, max_new_tokens):
+         new_history[-1] = (processed_msg, partial)
+         yield "", new_history, new_system, processed_msg  # keep last_user for regenerate
+
+ def regenerate(chat_history, system_prompt, temperature, max_new_tokens, last_user):
+     chat_history = chat_history or []
+     if not chat_history:
+         yield chat_history
+         return
+     # Re-answer the last user message, taken from state (fall back to the last chat turn)
+     user_msg = last_user or chat_history[-1][0]
+     if not user_msg:
+         yield chat_history
+         return
+
+     # Remove the last assistant turn if it matches last_user
+     if chat_history and chat_history[-1][0] == user_msg:
+         chat_history.pop()
+
+     # Build prompt from the remaining history
+     prompt = build_prompt(system_prompt, chat_history, user_msg)
+     chat_history.append((user_msg, ""))
+
+     for partial in stream_generate(prompt, temperature, max_new_tokens):
+         chat_history[-1] = (user_msg, partial)
+         yield chat_history
+
+ def clear_memory():
+     return [], ""
+
+ # -----------------------------
+ # Inference API adapter (so /models/<user>/<space> works)
+ # Accepts either a plain string or JSON:
+ #   { "inputs": "...", "system": "...", "options": { "temperature": 0.7, "max_new_tokens": 256 }, "history": [...] }
+ # -----------------------------
+ def hf_inference_api(inputs):
      try:
+         # If inputs is dict-like, use it; else treat it as a plain prompt
+         if isinstance(inputs, dict):
+             prompt_text = inputs.get("inputs", "")
+             system = inputs.get("system", DEFAULT_SYSTEM)
+             options = inputs.get("options", {}) or {}
+             temp = float(options.get("temperature", DEFAULT_TEMPERATURE))
+             max_new = int(options.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
+             history = inputs.get("history", [])
+         else:
+             prompt_text = str(inputs)
+             system = DEFAULT_SYSTEM
+             temp = DEFAULT_TEMPERATURE
+             max_new = DEFAULT_MAX_NEW_TOKENS
+             history = []
+
+         prompt = build_prompt(system, history, prompt_text)
+         out = ""
+         for chunk in stream_generate(prompt, temp, max_new):
+             out = chunk
+         # Return the final text
+         return out
      except Exception as e:
+         return f"(Error: {e})"
 
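+ # Client sketch (the Space id is a placeholder; assumes the "/predict" route wired below):
+ #   from gradio_client import Client
+ #   client = Client("<user>/<space>")
+ #   print(client.predict("Explain transformers in one paragraph.", api_name="/predict"))
+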
  # -----------------------------
+ # UI (ChatGPT-like)
  # -----------------------------
+ CSS = """
+ :root { --radius: 14px; }
+ #chatbot { height: 70vh !important; }
+ .gradio-container { max-width: 1200px !important; margin: auto; }
+ """
+
+ with gr.Blocks(title="UniversalAI — ChatGPT-style", css=CSS, theme=gr.themes.Soft()) as demo:
+     gr.Markdown("### UniversalAI — ChatGPT-style")
 
      with gr.Row():
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(
+                 label="Chat",
+                 bubble_full_width=False,
+                 render_markdown=True,
+                 likeable=True,
+                 layout="bubble",
+                 height=520,
+                 elem_id="chatbot"
              )
              with gr.Row():
+                 user_box = gr.Textbox(
+                     placeholder="Message UniversalAI… (commands: /reset, /system: <prompt>)",
+                     show_label=False,
+                     lines=3
                  )
+             with gr.Row():
+                 send_btn = gr.Button("Send", variant="primary")
+                 regen_btn = gr.Button("Regenerate", variant="secondary")
+                 clear_btn = gr.Button("Clear", variant="secondary")
 
+         with gr.Column(scale=2):
+             sys_box = gr.Textbox(
+                 value=DEFAULT_SYSTEM,
+                 label="System prompt",
+                 lines=6
              )
+             temp_slider = gr.Slider(
+                 minimum=0.1, maximum=1.2, value=DEFAULT_TEMPERATURE, step=0.05,
+                 label="Creativity (temperature)"
+             )
+             max_tokens = gr.Slider(
+                 minimum=64, maximum=2048, value=DEFAULT_MAX_NEW_TOKENS, step=32,
+                 label="Max new tokens"
+             )
+             gr.Markdown("> Tip: Use /reset to clear memory. Use /system: to change the assistant persona on the fly.")
+
+     # Session state
+     state_history = gr.State([])   # list[(user, assistant)]
+     state_last_user = gr.State("") # last user message for regenerate
+
+     # Wiring
+     send_evt = send_btn.click(
+         fn=chat_submit,
+         inputs=[user_box, state_history, sys_box, temp_slider, max_tokens, state_last_user],
+         outputs=[user_box, chatbot, sys_box, state_last_user],
+         queue=True
      )
+     send_evt.then(lambda h: h, inputs=chatbot, outputs=state_history)
+
+     # Allow Enter to send
+     enter_evt = user_box.submit(
+         fn=chat_submit,
+         inputs=[user_box, state_history, sys_box, temp_slider, max_tokens, state_last_user],
+         outputs=[user_box, chatbot, sys_box, state_last_user],
+         queue=True
+     )
+     enter_evt.then(lambda h: h, inputs=chatbot, outputs=state_history)
 
+     regen_stream = regen_btn.click(
+         fn=regenerate,
+         inputs=[state_history, sys_box, temp_slider, max_tokens, state_last_user],
+         outputs=[chatbot],
+         queue=True
      )
+     regen_stream.then(lambda h: h, inputs=chatbot, outputs=state_history)
+
+     clear_btn.click(fn=clear_memory, inputs=None, outputs=[chatbot, state_last_user])
 
+     # Expose a simple route for Inference API callers. gr.Blocks has no add_api_route
+     # method, so register a hidden event with an api_name instead.
+     api_in = gr.Textbox(visible=False)
+     api_out = gr.Textbox(visible=False)
+     api_btn = gr.Button(visible=False)
+     api_btn.click(fn=hf_inference_api, inputs=api_in, outputs=api_out, api_name="predict")
 
  if __name__ == "__main__":
+     demo.queue().launch()