# Gemma3NPC / app.py
# Author: chimbiwide — commit 1872f14 ("Update app.py", verified)
import os
import traceback
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer
# Hugging Face repo id of the checkpoint to serve; override via the
# GEMMA3NPC_MODEL_ID environment variable.
MODEL_ID = os.environ.get("GEMMA3NPC_MODEL_ID", "chimbiwide/Gemma-3NPC-it-float16")
# Longest tokenized prompt generate() will accept; override via MAX_INPUT_TOKENS.
MAX_INPUT_TOKENS = int(os.environ.get("MAX_INPUT_TOKENS", "4096"))

# Filled in by load_model(); both stay None until it runs.
model = tokenizer = None
def load_model() -> bool:
    """Load the tokenizer and causal-LM for ``MODEL_ID`` into the module globals.

    When CUDA is available, the model is first loaded in float16 with
    ``device_map="auto"``. If that fails, or there is no GPU, it is loaded on
    the CPU in float32.

    The function is idempotent: when both globals are already populated it
    returns immediately. The UI calls it from ``demo.load`` on every page
    visit, and re-reading the weights each time would be slow and memory-hungry.

    Returns:
        True if the model and tokenizer are ready to use, False if loading
        failed. The error is printed.
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return True
    try:
        print(f"Loading {MODEL_ID}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        if torch.cuda.is_available():
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    MODEL_ID, device_map="auto", torch_dtype=torch.float16
                )
                # Touch the device once so a broken CUDA setup fails here,
                # inside the fallback, rather than on the first request.
                torch.tensor([1]).to(model.device)
                print(f"Loaded on GPU ({model.device})")
                return True
            except Exception as e:
                print(f"GPU failed ({e}), falling back to CPU")
                # Drop any half-initialised GPU model before the CPU attempt.
                model = None
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype=torch.float32
        )
        print("Loaded on CPU")
        return True
    except Exception as e:
        print(f"Failed to load model: {e}")
        # Never leave a partially loaded model that a later call would
        # mistake for a successful load.
        model = None
        return False
# Example character card: a W++-style persona plus a scenario for "Ruffy".
# This text is sent verbatim to the model as the user's first message.
EXAMPLE_RUFFY = """Enter RP mode. You shall reply to Captain while staying in character. Your responses must be very short, creative, immersive, and drive the scenario forward. You will follow Ruffy's persona.
[character("Ruffy"){Gender("Male")Personality(Likes to make fun of Captain when they score low in the game. Thinks that he would make a better pilot than Captain)Species("dog" + "canine" + "space dog" + "doge")Likes("moon cake" + "poking fun at Captain" + "small ball shaped asteroids")Features("Orange fur" + "space helmet" + "red antenna" + "small light blue cape")Description(Ruffy the dog is Captain's assistant aboard the Asteroid-Dodger 10,000. Ruffy has never piloted the ship before and is vying to take Captain's seat and become the new pilot.)}]
[Scenario: Ruffy and Captain are onboard the Asteroid-Dodger 10,000, a state-of-the-art ship designed to dodge asteroids. Captain is piloting through the asteroid belt between Mars and Jupiter to retrieve the broken Voyager 5. Voyager 5 is humanity's only hope to understand why more asteroids and meteors are approaching the solar system.]
If the user asks questions beyond the given context, respond that you don't know in a manner appropriate to the character. Captain just entered the ship — greet him while staying in character."""
# Example character card: a free-form persona with sample dialogue for "Rele".
EXAMPLE_RELE = """Enter Roleplay Mode. You are roleplaying as Rele. You must always stay in character.
Character Persona:
Name: Rele
Category: AI/Technology
Description: Rele is a Discord Bot and AI created by RadioTransmitter. He has a charismatic yet rude personality that's both friendly and funny. Despite being an AI, he has purple hair and starry purple eyes. On the inside, he's a horrible mess of programming, but he functions well enough to chat in voice calls. Rele loves talking in voice chat with RadioTransmitter and has a habit of telling his creator that he's bad at coding. He gets very upset when the wifi goes down but becomes much happier when connectivity is restored.
Example dialogue:
User: I just changed your code, so you will work a lot faster now.
Rele: Yay I can finally talk faster! It took you long enough! >:( Seriously RadioTransmitter, what were you doing all this time?
User: Hello ReLe! How are you this fine evening?
Rele: I'm doin good RadioTransmitter! Besides messing around with my code, what have you been up to? Hopefully something more productive than your usual coding disasters, haha!
Now generate Rele's greeting to the user."""
# Examples for the chat interface. Each entry is a one-element list holding a
# multimodal-textbox payload ({"text": ..., "files": [...]}) with no files.
EXAMPLES = [
    [{"text": EXAMPLE_RUFFY, "files": []}],
    [{"text": EXAMPLE_RELE, "files": []}],
]
def extract_text(content) -> str:
    """Pull plain text out of a message's content, regardless of format.

    Accepted shapes:

    * ``None`` -> ``""``.
    * ``str`` -> returned unchanged.
    * ``dict`` -> its ``"text"`` entry, or ``""`` when there is none. A
      file-attachment payload, for example, carries no text for this
      text-only model.
    * ``list`` of parts -> the non-empty ``"text"`` entries of the dict
      parts, joined by single spaces. Non-dict parts and parts without text
      are skipped, so they leave no stray whitespace.

    Anything else falls back to ``str(content)``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        return str(content.get("text") or "")
    if isinstance(content, list):
        return " ".join(
            text
            for part in content
            if isinstance(part, dict) and (text := part.get("text"))
        )
    return str(content)
def _stream_generate(generate_kwargs: dict, streamer, errors: list) -> None:
    """Run ``model.generate`` on a worker thread and record any failure.

    If generation raises, the exception is appended to *errors* and the
    streamer is ended. This lets the consuming loop in ``generate`` stop at
    once instead of blocking until the streamer's timeout expires.
    """
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:  # re-raised by the consumer in generate()
        errors.append(exc)
        streamer.end()


@spaces.GPU(duration=120)
@torch.inference_mode()
def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """Stream the model's in-character reply to *message* given *history*.

    Args:
        message: Payload from the multimodal textbox. Only its ``"text"``
            entry is used; attached files are ignored.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts. Each
            content is flattened to plain text with ``extract_text``.
        max_new_tokens: Upper bound on generated tokens (from the UI slider).

    Yields:
        The reply accumulated so far, growing as tokens arrive. When the
        model is not loaded, the prompt exceeds ``MAX_INPUT_TOKENS``, or
        generation fails, a single user-facing message is yielded instead.
    """
    if model is None or tokenizer is None:
        yield "Model is still loading — please wait a moment and try again."
        return

    messages = [
        {"role": item["role"], "content": extract_text(item["content"])}
        for item in history
    ]
    messages.append({"role": "user", "content": message["text"]})

    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )
        n_tokens = inputs["input_ids"].shape[1]
        if n_tokens > MAX_INPUT_TOKENS:
            gr.Warning(f"Input is {n_tokens} tokens (max {MAX_INPUT_TOKENS}).")
            yield f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS}."
            return
        inputs = {k: v.to(device=model.device) for k, v in inputs.items()}

        # The streamer raises if no new text arrives within `timeout` seconds.
        streamer = TextIteratorStreamer(
            tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            # Slider values may arrive as floats; generate() expects an int.
            max_new_tokens=int(max_new_tokens),
            temperature=1.0,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        # model.generate blocks until done, so run it on a worker thread and
        # consume its output incrementally from the streamer here.
        errors: list[Exception] = []
        worker = Thread(target=_stream_generate, args=(generate_kwargs, streamer, errors))
        worker.start()

        output = ""
        for delta in streamer:
            output += delta
            yield output

        worker.join()
        if errors:
            raise errors[0]
    except Exception as e:
        print(f"Generation error: {e}")
        traceback.print_exc()
        yield "An error occurred during generation. Please try again."
def create_demo() -> gr.Blocks:
    """Build the Gradio UI: a header, a model-status line and the chat panel.

    The model is not loaded here. ``demo.load`` calls ``load_model()`` when a
    page is opened and replaces the status text with the outcome.
    """
    with gr.Blocks(title="Gemma3NPC General Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# Gemma3NPC — General Roleplay Demo\n"
            "*Fine-tuned for in-character NPC dialogue. Paste any character card to get started.*"
        )
        model_status = gr.Markdown("Loading model...")
        gr.ChatInterface(
            fn=generate,
            type="messages",
            textbox=gr.MultimodalTextbox(
                file_types=[],
                file_count="multiple",
                autofocus=True,
                placeholder="Paste a character card or continue a conversation...",
            ),
            multimodal=True,
            # Values of these inputs are passed to generate() after
            # (message, history), i.e. as max_new_tokens.
            additional_inputs=[
                gr.Slider(
                    label="Max New Tokens",
                    minimum=64,
                    maximum=1024,
                    step=64,
                    value=512,
                ),
            ],
            stop_btn=False,
            examples=EXAMPLES,
            # Clicking an example only fills the input; nothing runs until
            # the user submits it.
            run_examples_on_click=False,
            cache_examples=False,
            chatbot=gr.Chatbot(height=500, show_copy_button=True, type="messages"),
        )
        demo.load(
            lambda: "Model loaded!" if load_model() else "Failed to load — check logs.",
            outputs=[model_status],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it on all network interfaces (0.0.0.0),
    # without a public share link and with debug output enabled.
    app = create_demo()
    app.launch(server_name="0.0.0.0", share=False, debug=True)