thisistheend / app.py
import os
import threading
from typing import Any, List
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
# --- 1. AUTHENTICATION ---
# Ensure HF_TOKEN is set in your Space Secrets
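# Logging in is only required for gated models; gpt2-xl is public, so this step is optional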
HF_TOKEN = os.environ.get("HF_TOKEN", None)
if HF_TOKEN:
    login(token=HF_TOKEN)
# --- 2. MODEL SETUP ---
MODEL_ID = "openai-community/gpt2-xl"
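# NOTE: gpt2-xl is a plain causal LM with no chat template; the Gemma-style
# <start_of_turn> markers used in the prompt below are ordinary text to its tokenizer.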
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # 'dtype' replaces the deprecated 'torch_dtype' argument in recent transformers releases
    dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
)
# --- 3. CHAT LOGIC ---
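# @spaces.GPU requests a GPU for the duration of each call on ZeroGPU Spaces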
@spaces.GPU
def chat(message: str, history: List[Any]):
"""
Handles conversation history flexibly for any Gradio version.
"""
prompt = ""
for entry in history:
# Check if entry is Gradio 5+ Dictionary format
if isinstance(entry, dict):
role = entry.get("role")
content = entry.get("content")
if role == "user":
prompt += f"<start_of_turn>user\n{content}<end_of_turn>\n"
else:
prompt += f"<start_of_turn>model\n{content}<end_of_turn>\n"
# Check if entry is Gradio 4- Tuple format
elif isinstance(entry, (list, tuple)):
user_msg, bot_msg = entry
prompt += f"<start_of_turn>user\n{user_msg}<end_of_turn>\n"
prompt += f"<start_of_turn>model\n{bot_msg}<end_of_turn>\n"
# Add current user message
prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    # skip_prompt=True keeps the echoed input prompt out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        # NOTE: gpt2-xl has a 1024-token context window, so a long history plus
        # max_new_tokens=1024 can exceed it and truncate or fail generation
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )
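    # Run generate() in a background thread so we can iterate the streamer
    # and yield partial text back to Gradio as it is produced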
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
# --- 4. GRADIO UI ---
# Keep Blocks constructor empty to avoid parameter migration errors
with gr.Blocks() as demo:
gr.Markdown("# 💎 Google Gemma 2 Chat")
gr.Markdown("Zero-training implementation optimized for your environment.")
    gr.ChatInterface(
        fn=chat,
        examples=["Tell me a fun fact about space.", "Write a short email to a client."],
    )
if __name__ == "__main__":
    # Theme is moved here to satisfy the Gradio 6.0 warning
    demo.launch(theme=gr.themes.Soft())