# Spaces: Sleeping — Hugging Face Space status banner captured alongside the
# file when it was scraped; not part of the application code.
| import os | |
| import threading | |
| from typing import List, Tuple, Dict, Any | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
| from huggingface_hub import login | |
# --- 1. AUTHENTICATION ---
# Log in to the Hugging Face Hub only when a token is present
# (HF_TOKEN must be configured in the Space's secrets).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
if HF_TOKEN:
    login(token=HF_TOKEN)

# --- 2. MODEL SETUP ---
MODEL_ID = "openai-community/gpt2-xl"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Pick the compute dtype once, up front: bf16 on GPU, fp32 on CPU.
_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # 'dtype' is the current name for the deprecated 'torch_dtype' argument.
    dtype=_DTYPE,
    low_cpu_mem_usage=True,
)
| # --- 3. CHAT LOGIC --- | |
def _build_prompt(message: str, history: List[Any]) -> str:
    """Flatten the chat history plus the new user message into a single
    Gemma-style prompt string ending with an open model turn.

    Accepts both Gradio history formats:
      * Gradio 5+:   [{"role": ..., "content": ...}, ...]
      * Gradio <=4:  [(user_msg, bot_msg), ...]

    NOTE(review): <start_of_turn>/<end_of_turn> are Gemma's chat-template
    markers, but MODEL_ID points at GPT-2, which was never trained on them —
    the tokenizer treats them as plain text. Confirm the intended model.
    """
    parts: List[str] = []
    for entry in history:
        if isinstance(entry, dict):
            # Gradio 5+ "messages" format; any non-"user" role maps to model.
            role = entry.get("role")
            content = entry.get("content")
            turn = "user" if role == "user" else "model"
            parts.append(f"<start_of_turn>{turn}\n{content}<end_of_turn>\n")
        elif isinstance(entry, (list, tuple)):
            # Gradio <=4 tuple format: (user message, bot message).
            user_msg, bot_msg = entry
            parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>\n")
            parts.append(f"<start_of_turn>model\n{bot_msg}<end_of_turn>\n")
    # Current user message, leaving the model turn open for generation.
    parts.append(f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n")
    return "".join(parts)


def chat(message: str, history: List[Any]):
    """Stream a model reply for *message* given the Gradio chat *history*.

    Yields progressively longer partial responses so the UI can render the
    answer token-by-token as it is generated.
    """
    prompt = _build_prompt(message, history)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # GPT-2's context window is only ~1024 positions. Requesting more new
    # tokens than fit after the prompt would push generate() past the
    # position-embedding table, so clamp the budget to what actually fits.
    context_limit = getattr(model.config, "max_position_embeddings", 1024)
    prompt_len = inputs["input_ids"].shape[1]
    max_new = max(1, min(1024, context_limit - prompt_len))

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        # GPT-2 defines no pad token; reuse EOS so generate() doesn't warn/fail.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Generation runs on a worker thread; the streamer feeds tokens back here.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted only when generation ends, so this join is
    # immediate — it just avoids leaving a dangling thread behind.
    thread.join()
# --- 4. GRADIO UI ---
# Blocks() is kept argument-free to sidestep constructor-parameter changes
# between Gradio major versions.
with gr.Blocks() as demo:
    gr.Markdown("# 💎 Google Gemma 2 Chat")
    gr.Markdown("Zero-training implementation optimized for your environment.")

    EXAMPLE_PROMPTS = ["Tell me a fun fact about space.", "Write a short email to a client."]
    gr.ChatInterface(
        fn=chat,
        examples=EXAMPLE_PROMPTS,
    )

if __name__ == "__main__":
    # NOTE(review): theme is passed to launch() here (the original comment cites
    # a Gradio 6.0 warning); confirm the pinned Gradio version accepts 'theme'
    # on launch() — older releases take it on the Blocks() constructor instead.
    demo.launch(theme=gr.themes.Soft())