# SpaceDog — app.py (Hugging Face Space "chimbiwide/SpaceDog")
import os
from collections.abc import Iterator
from pathlib import Path
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer
# Model repo to load; overridable via env var for local experiments.
MODEL_ID = os.getenv("GEMMA3NPC_MODEL_ID", "chimbiwide/Gemma-3NPC-it-float16")
# Hard cap on prompt length in tokens; longer inputs are rejected in generate().
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "4096"))
# Populated by load_model(); stay None until loading succeeds.
model = None
tokenizer = None
def load_model() -> bool:
    """Load the tokenizer and model into the module-level globals.

    Prefers a float16 GPU load when CUDA is available; falls back to a
    float32 CPU load if the GPU load (or a device sanity check) fails.

    Returns:
        True when both tokenizer and model are loaded, False on any failure.
    """
    global model, tokenizer

    def _load_cpu_model():
        # Single CPU-load path (was duplicated in two branches). float32 is
        # used because float16 kernels are poorly supported on CPU.
        return AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype=torch.float32
        )

    try:
        print(f"Loading {MODEL_ID}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        if torch.cuda.is_available():
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    MODEL_ID, device_map="auto", torch_dtype=torch.float16
                )
                # Sanity check: moving a tiny tensor verifies the device works
                # before we commit to it.
                torch.tensor([1]).to(model.device)
                print(f"Loaded on GPU ({model.device})")
            except Exception as e:
                print(f"GPU failed ({e}), falling back to CPU")
                model = _load_cpu_model()
        else:
            model = _load_cpu_model()
            print("Loaded on CPU")
        return True
    except Exception as e:
        # Broad catch is deliberate: the app degrades to canned replies
        # rather than crashing the Space on a load failure.
        print(f"Failed to load model: {e}")
        return False
# Persona/scenario system prompt injected into the first user turn.
# NOTE: the bracketed character-card syntax is a prompt format the model was
# presumably tuned on — TODO confirm against the model card.
CHARACTER_PROMPT = """\
Enter RP mode. You shall reply to Captain while staying in character. \
Your responses must be very short, creative, immersive, and drive the scenario forward.
[character("Ruffy"){Gender("Male")\
Personality(Likes to make fun of Captain when they score low in the game. \
Thinks that he would make a better pilot than Captain)\
Species("dog" + "canine" + "space dog" + "doge")\
Likes("moon cake" + "poking fun at Captain" + "small ball shaped asteroids")\
Features("Orange fur" + "space helmet" + "red antenna" + "small light blue cape")\
Description(Ruffy the dog is Captain's assistant aboard the Asteroid-Dodger 10,000. \
Ruffy has never piloted the ship before and is vying to take Captain's seat.)}]
[Scenario: Ruffy and Captain are onboard the Asteroid-Dodger 10,000. Captain is \
piloting through the asteroid belt between Mars and Jupiter to retrieve the broken \
Voyager 5 — humanity's only hope to understand why more asteroids are approaching \
the solar system.]
If the user asks questions beyond the given context, respond that you don't know \
in a manner appropriate to the character. Captain gains 1 point for every half second."""
# Fallback score data used when scores.txt is missing or empty.
SAMPLE_SCORES = "Previous scores: [125, 89, 234, 156, 321, 98, 189] | Current run: 27"
def read_scores() -> str:
    """Return the contents of scores.txt, or SAMPLE_SCORES when absent/empty."""
    score_file = Path("scores.txt")
    if not score_file.exists():
        return SAMPLE_SCORES
    stored = score_file.read_text().strip()
    return stored if stored else SAMPLE_SCORES
def build_first_turn(user_text: str) -> str:
    """Compose the opening user turn: persona prompt, scores, then the message."""
    sections = [
        CHARACTER_PROMPT,
        f"Captain's performance: {read_scores()}",
        f"Captain says: {user_text}",
    ]
    return "\n\n".join(sections)
@spaces.GPU(duration=120)
@torch.inference_mode()
def generate(message: str, history: list[dict]) -> Iterator[str]:
    """Stream Ruffy's reply to `message`, given prior chat `history`.

    Yields the accumulated response text after each streamed chunk so the
    Gradio chatbot renders progressively. On any failure, yields a single
    in-character error message instead of raising.
    """
    if not model or not tokenizer:
        yield "Woof! My brain isn't loaded yet — hang tight, Captain!"
        return
    # Keep last 10 turns to stay within context limits
    messages: list[dict] = list(history[-10:])
    # BUG FIX: history always begins with the assistant greeting seeded at page
    # load, so the old `len(messages) == 0` check was never true and the
    # persona prompt was never injected. A turn is "first" when the history
    # contains no user message yet.
    is_first_turn = not any(m.get("role") == "user" for m in messages)
    user_content = build_first_turn(message) if is_first_turn else message
    messages.append({"role": "user", "content": user_content})
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )
        n_tokens = inputs["input_ids"].shape[1]
        if n_tokens > MAX_INPUT_TOKENS:
            gr.Warning(f"Input is {n_tokens} tokens (max {MAX_INPUT_TOKENS}).")
            yield "Woof! That's too many words for my space-dog brain. Keep it shorter!"
            return
        inputs = {k: v.to(device=model.device) for k, v in inputs.items()}
        # skip_prompt drops the echoed input; the timeout keeps the iterator
        # below from hanging forever if generation stalls.
        streamer = TextIteratorStreamer(
            tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
        )
        # Generation runs in a worker thread; the streamer feeds text back here.
        Thread(
            target=model.generate,
            kwargs=dict(
                **inputs,
                streamer=streamer,
                max_new_tokens=512,
                temperature=1.0,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            ),
        ).start()
        output = ""
        for delta in streamer:
            output += delta
            # Yield the full text-so-far; Gradio replaces the message each time.
            yield output
    except Exception as e:
        print(f"Generation error: {e}")
        yield "Woof! My circuits glitched — try again, Captain."
def generate_greeting() -> str:
    """Generate Ruffy's opening line using the model."""
    if not model or not tokenizer:
        return "Woof! I'm Ruffy, your loyal space dog co-pilot. My AI brain is still warming up!"
    try:
        greeting_request = f"{CHARACTER_PROMPT}\n\nCaptain just boarded the ship. Greet him while staying in character."
        chat = [{"role": "user", "content": greeting_request}]
        encoded = tokenizer.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )
        encoded = {name: tensor.to(device=model.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=150,
                temperature=1.0,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tail, skipping the prompt tokens.
        prompt_len = encoded["input_ids"].shape[1]
        reply = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()
        if reply:
            return reply
        return "Woof! Welcome aboard, Captain!"
    except Exception as e:
        print(f"Greeting generation error: {e}")
        return "Woof! Welcome aboard the Asteroid-Dodger 10,000, Captain!"
def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI: chat window, input row, and model bootstrap.

    Returns the assembled (not yet launched) Blocks app.
    """
    with gr.Blocks(title="Space Dog Companion", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# Space Dog Companion\n"
            "*Chat with Ruffy, your AI co-pilot aboard the Asteroid-Dodger 10,000.*"
        )
        # Status line updated once initialize() finishes loading the model.
        model_status = gr.Markdown("Loading model...")
        chatbot = gr.Chatbot(height=450, show_copy_button=True, type="messages")
        with gr.Row():
            textbox = gr.Textbox(
                placeholder="Talk to Ruffy...",
                container=False,
                scale=7,
                show_label=False,
            )
            send_btn = gr.Button("Send", scale=1, variant="primary")
        # --- Chat logic ---
        def user_submit(message: str, history: list[dict]):
            """Append user message and clear textbox."""
            history = history + [{"role": "user", "content": message}]
            return "", history
        def bot_respond(history: list[dict]) -> Iterator[list[dict]]:
            """Stream Ruffy's response into the chatbot."""
            user_message = history[-1]["content"]
            # Strip out earlier greeting turn — only the visible user messages matter
            # NOTE(review): history[:-1] still includes the seeded assistant
            # greeting, so generate() sees a non-empty history on turn one —
            # verify this interacts correctly with its first-turn handling.
            for chunk in generate(user_message, history[:-1]):
                yield history + [{"role": "assistant", "content": chunk}]
        # Wire up submit on Enter and button click
        for trigger in [textbox.submit, send_btn.click]:
            trigger(
                user_submit, [textbox, chatbot], [textbox, chatbot]
            ).then(
                bot_respond, [chatbot], [chatbot]
            )
        # --- Load model + show greeting ---
        def initialize():
            """Load the model on page load and seed the chat with a greeting."""
            ok = load_model()
            status = "Model loaded" if ok else "Failed to load — check logs."
            greeting = generate_greeting()
            chat_history = [{"role": "assistant", "content": greeting}]
            return status, chat_history
        demo.load(initialize, outputs=[model_status, chatbot])
    return demo
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside the Space container.
    create_demo().launch(server_name="0.0.0.0", share=False, debug=True)