Spaces:
Sleeping
Sleeping
File size: 3,970 Bytes
9774535 17cbe2a 9774535 17cbe2a 9774535 17cbe2a 6ff22fa 9774535 6ff22fa 9774535 6ff22fa 9774535 17cbe2a 9774535 6ff22fa 17cbe2a 9774535 17cbe2a 9774535 17cbe2a 6ff22fa 9774535 17cbe2a 9774535 6ff22fa 17cbe2a 9774535 6ff22fa 9774535 17cbe2a 9774535 17cbe2a 9774535 17cbe2a 9774535 17cbe2a 9774535 17cbe2a 9774535 17cbe2a 9774535 6ff22fa 17cbe2a 6ff22fa 17cbe2a 9774535 17cbe2a 9774535 17cbe2a 9774535 17cbe2a 6ff22fa 17cbe2a 9774535 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# app.py
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList
)
import torch
import gradio as gr
# ======================
# Configuration
# ======================
# Model repo id on the Hugging Face Hub; also shown in the UI title below.
MODEL_ID = "microsoft/Phi-3-mini-128k-instruct"
# ======================
# Load Model & Tokenizer
# ======================
# NOTE: runs at import time — the first launch downloads several GB of weights.
print(f"🚀 Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,    # half precision to fit the model in GPU memory
    device_map="auto",            # let accelerate place layers on available devices
    trust_remote_code=False,      # refuse to execute code shipped with the repo
    attn_implementation="eager"   # Use "flash_attention_2" if installed
)
print("✅ Model loaded successfully!")
# ======================
# Stopping Criteria
# ======================
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last emitted token is a stop token.

    Used to halt on Phi-3's ``<|end|>`` turn terminator in addition to the
    standard EOS token.
    """

    def __init__(self, stop_token_ids):
        """
        Args:
            stop_token_ids (Iterable[int]): Token ids that terminate generation.
        """
        super().__init__()  # initialize the StoppingCriteria base class
        # Store as a set for O(1) membership tests in the per-step __call__.
        self.stop_token_ids = set(stop_token_ids)

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the most recently generated token is a stop token.

        Args:
            input_ids: (batch, seq_len) tensor of generated ids; only the
                first sequence is inspected (this app generates batch size 1).
            scores: Unused logits, required by the StoppingCriteria interface.
        """
        # .item() turns the 0-d tensor into a plain int for the set lookup.
        return input_ids[0, -1].item() in self.stop_token_ids
# Collect the token ids that should terminate generation.
stop_token_ids = [
    tokenizer.eos_token_id,  # Standard EOS
]
# Add Phi-3's <|end|> turn terminator if the tokenizer actually knows it.
# NOTE: convert_tokens_to_ids returns unk_token_id (a non-negative int), not
# None or -1, for tokens missing from the vocabulary — so a plain `>= 0`
# check cannot detect absence; compare against unk_token_id as well.
end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
if (
    isinstance(end_token_id, int)
    and end_token_id >= 0
    and end_token_id != tokenizer.unk_token_id
):
    stop_token_ids.append(end_token_id)
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])
# ======================
# Response Function
# ======================
def respond(message: str, history):
    """
    Generate a single response from the Phi-3 model.

    Args:
        message (str): New user input.
        history (list[dict] | None): Prior turns in
            {"role": ..., "content": ...} format (Gradio "messages" type).

    Returns:
        str: The model's reply (new tokens only, special tokens removed,
        surrounding whitespace stripped).
    """
    # Ignore empty or whitespace-only submissions.
    if not message or not message.strip():
        return ""
    # Gradio normally passes a list, but guard against None for robustness.
    history = history or []
    # Build the full conversation including the new user turn.
    messages = history + [{"role": "user", "content": message}]
    # Render with the model's chat template; add_generation_prompt appends
    # the assistant header so the model continues as the assistant.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Tokenize, truncating to the model's 128k-token context window.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=128000,
    ).to(model.device)
    print('Tokenized input: ', inputs)
    # Inference only — no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.1,        # near-greedy but still sampled
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # silence missing-pad warning
            stopping_criteria=stopping_criteria,  # also stops on <|end|>
        )
    # Decode only the newly generated tokens (everything after the prompt).
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    print('Response: ', response)
    # strip() drops the stray leading/trailing whitespace chat templates emit.
    return response.strip()  # Gradio will auto-append to chat history
# ======================
# Gradio Interface
# ======================
# ChatInterface wires `respond` into a full chat UI: it manages the history
# list and appends the returned string as the assistant turn.
demo = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(
        height=600,
        type="messages"  # Required for Gradio v5
    ),
    textbox=gr.Textbox(
        placeholder="Ask me anything about AI, science, coding, and more...",
        container=False,
        scale=7          # widen the input box relative to its siblings
    ),
    title="🧠 Phi-3 Mini (128K Context) Chat",
    description="""
A demo of Microsoft's **Phi-3-mini-128k-instruct** model — a powerful small LLM with support for ultra-long context.
Try asking it to summarize long texts, explain complex topics, or write code.
""",
    examples=[
        "Who are you?",
        "Explain quantum entanglement simply.",
        "Write a Python function to detect cycles in a linked list."
    ],
    # Note: retry_btn, undo_btn, clear_btn removed — not supported in v5
    # Toolbar appears automatically
)
# ======================
# Launch
# ======================
# Start the web server only when run as a script, not when imported
# (e.g. by the Spaces runtime or a test harness).
if __name__ == "__main__":
    demo.launch()
|