import os
import threading
from typing import Any, List

import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login

# --- 1. AUTHENTICATION ---
# Ensure HF_TOKEN is set in your Space Secrets.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
if HF_TOKEN:
    login(token=HF_TOKEN)

# --- 2. MODEL SETUP ---
MODEL_ID = "openai-community/gpt2-xl"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# GPT-2 ships without a pad token; generate() warns (or fails when padding is
# needed) unless one is defined, so reuse EOS as the pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # Using 'dtype' instead of 'torch_dtype' per the earlier deprecation warning.
    dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
)

# --- 3. CHAT LOGIC ---
@spaces.GPU
def chat(message: str, history: List[Any]):
    """Stream a model reply for *message*, given the prior chat *history*.

    Accepts both Gradio history formats:
      * Gradio 5+ "messages": list of {"role": ..., "content": ...} dicts
      * Gradio <=4 "tuples":  list of (user_msg, bot_msg) pairs

    Yields:
        The accumulated partial response text as tokens are generated.
    """
    prompt = ""
    for entry in history:
        if isinstance(entry, dict):
            # Gradio 5+ dictionary format.
            role = entry.get("role")
            content = entry.get("content")
            if role == "user":
                prompt += f"user\n{content}\n"
            else:
                prompt += f"model\n{content}\n"
        elif isinstance(entry, (list, tuple)):
            # Gradio <=4 tuple format. bot_msg is None for a still-pending
            # turn; skip it rather than injecting the literal string "None".
            user_msg, bot_msg = entry
            prompt += f"user\n{user_msg}\n"
            if bot_msg is not None:
                prompt += f"model\n{bot_msg}\n"

    # Add the current user message and cue the model's turn.
    prompt += f"user\n{message}\nmodel\n"

    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        # Explicit pad token silences the "Setting pad_token_id to eos_token_id"
        # warning GPT-2 otherwise emits on every generation.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread so this generator can stream
    # tokens as they arrive. daemon=True keeps a hung generation from
    # blocking interpreter shutdown.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs, daemon=True)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # Streamer is exhausted only after generate() finishes; join to reap the thread.
    thread.join()
# --- 4. GRADIO UI ---
# Keep the Blocks constructor empty to avoid parameter migration errors.
with gr.Blocks() as demo:
    # Title now matches MODEL_ID (openai-community/gpt2-xl); the previous
    # "Google Gemma 2" heading misrepresented the model being served.
    gr.Markdown("# 🤖 GPT-2 XL Chat")
    gr.Markdown("Zero-training implementation optimized for your environment.")
    gr.ChatInterface(
        fn=chat,
        examples=["Tell me a fun fact about space.", "Write a short email to a client."],
    )

if __name__ == "__main__":
    # Theme is moved here to satisfy the Gradio 6.0 warning.
    # NOTE(review): launch(theme=...) is unusual — confirm against the
    # installed Gradio version; themes are traditionally a Blocks() argument.
    demo.launch(theme=gr.themes.Soft())