# phi3-mini-demo / app.py
# thava's picture
# Updates to fix errors
# 17cbe2a
# app.py
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList
)
import torch
import gradio as gr
# ======================
# Configuration
# ======================
MODEL_ID = "microsoft/Phi-3-mini-128k-instruct"
# ======================
# Load Model & Tokenizer
# ======================
print(f"πŸš€ Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Loader options collected up front so each choice can carry a why-comment.
_model_kwargs = dict(
    torch_dtype=torch.float16,    # half precision to fit the weights in GPU memory
    device_map="auto",            # let accelerate place layers on available devices
    trust_remote_code=False,      # Phi-3 class ships with transformers itself
    attn_implementation="eager",  # Use "flash_attention_2" if installed
)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_model_kwargs)
print("βœ… Model loaded successfully!")
# ======================
# Stopping Criteria
# ======================
class StopOnTokens(StoppingCriteria):
    """Stopping criterion: end generation once the most recently generated
    token matches any id in a fixed stop set.

    Args:
        stop_token_ids: Iterable of token ids that should stop generation.
            ``None`` entries (e.g. a tokenizer with no ``eos_token_id``)
            are silently dropped.
    """

    def __init__(self, stop_token_ids):
        # A set gives O(1) membership tests instead of the original
        # per-call linear scan; filtering None guards against missing
        # special tokens poisoning the comparison.
        self.stop_token_ids = {tid for tid in stop_token_ids if tid is not None}

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids is (batch, seq); this app generates a single sequence,
        # so only batch index 0 is inspected. int(...) converts the 0-d
        # tensor to a plain int so a real bool is returned (the original
        # returned a 0-d tensor on match).
        return int(input_ids[0, -1]) in self.stop_token_ids
# Collect the token ids that should terminate generation.
stop_token_ids = [tokenizer.eos_token_id]  # standard EOS

# Phi-3 also emits an explicit <|end|> turn marker; include it only when
# the tokenizer actually maps it to a valid (non-negative int) id.
maybe_end_id = tokenizer.convert_tokens_to_ids("<|end|>")
if isinstance(maybe_end_id, int) and maybe_end_id >= 0:
    stop_token_ids.append(maybe_end_id)

stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])
# ======================
# Response Function
# ======================
def respond(message: str, history):
    """
    Generate one assistant reply from the Phi-3 model.

    Args:
        message (str): The new user input.
        history (List[dict]): Prior chat turns in
            {"role": ..., "content": ...} format.

    Returns:
        str: The model's reply text only (Gradio appends it to the
        chat history itself).
    """
    # Ignore empty / whitespace-only submissions.
    if not message.strip():
        return ""

    # Full conversation = prior turns plus the new user turn.
    conversation = history + [{"role": "user", "content": message}]

    # Render with the model's chat template, ending on the assistant cue
    # so the model continues as the assistant.
    prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize, truncating at the model's 128K-token context window.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=128000,
    ).to(model.device)
    print('Tokenized input: ', inputs)

    # Sampling settings: low temperature keeps answers focused while
    # still allowing some variation.
    generation_kwargs = dict(
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad warning
        stopping_criteria=stopping_criteria,
    )
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode only the freshly generated tokens (everything after the prompt).
    prompt_length = inputs.input_ids.shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    print('Response: ', response)
    return response  # Gradio will auto-append to chat history
# ======================
# Gradio Interface
# ======================
_EXAMPLES = [
    "Who are you?",
    "Explain quantum entanglement simply.",
    "Write a Python function to detect cycles in a linked list.",
]

demo = gr.ChatInterface(
    fn=respond,
    # type="messages" is required for Gradio v5's role/content dict format.
    chatbot=gr.Chatbot(height=600, type="messages"),
    textbox=gr.Textbox(
        placeholder="Ask me anything about AI, science, coding, and more...",
        container=False,
        scale=7,
    ),
    title="🧠 Phi-3 Mini (128K Context) Chat",
    description="""
A demo of Microsoft's **Phi-3-mini-128k-instruct** model β€” a powerful small LLM with support for ultra-long context.
Try asking it to summarize long texts, explain complex topics, or write code.
""",
    examples=_EXAMPLES,
    # retry_btn / undo_btn / clear_btn were removed in Gradio v5; the
    # chatbot toolbar provides those controls automatically.
)

# ======================
# Launch
# ======================
if __name__ == "__main__":
    # Start the web server only when executed as a script.
    demo.launch()