File size: 3,970 Bytes
9774535
17cbe2a
 
 
 
 
 
9774535
 
 
17cbe2a
 
 
 
9774535
 
 
17cbe2a
 
 
 
 
6ff22fa
9774535
 
 
6ff22fa
 
9774535
6ff22fa
9774535
 
17cbe2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9774535
6ff22fa
17cbe2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9774535
17cbe2a
 
 
 
 
 
9774535
17cbe2a
 
 
 
 
6ff22fa
 
9774535
17cbe2a
9774535
 
 
 
6ff22fa
17cbe2a
 
 
 
 
 
 
9774535
6ff22fa
9774535
 
 
17cbe2a
9774535
17cbe2a
9774535
 
17cbe2a
 
9774535
 
17cbe2a
 
 
 
9774535
17cbe2a
9774535
 
17cbe2a
 
 
9774535
 
6ff22fa
 
17cbe2a
 
 
 
 
 
6ff22fa
17cbe2a
 
 
 
 
9774535
17cbe2a
 
 
9774535
17cbe2a
 
9774535
 
17cbe2a
6ff22fa
17cbe2a
9774535
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# app.py
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList
)
import torch
import gradio as gr


# ======================
# Configuration
# ======================
# Hugging Face Hub model id; the "128k" variant supports very long contexts.
MODEL_ID = "microsoft/Phi-3-mini-128k-instruct"


# ======================
# Load Model & Tokenizer
# ======================
print(f"🚀 Loading model: {MODEL_ID}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,       # half precision: ~halves weight memory vs fp32
    device_map="auto",               # auto-place layers on available GPU(s)/CPU
    trust_remote_code=False,         # refuse to execute code shipped in the model repo
    attn_implementation="eager"  # Use "flash_attention_2" if installed
)

print("✅ Model loaded successfully!")


# ======================
# Stopping Criteria
# ======================
class StopOnTokens(StoppingCriteria):
    """Halt generation as soon as the most recent token is a stop token.

    Args:
        stop_token_ids: Iterable of token ids that should terminate
            generation. ``None`` entries (e.g. a tokenizer with no
            ``eos_token_id``) are silently dropped.
    """

    def __init__(self, stop_token_ids):
        # Store as a set: O(1) membership per generation step, and filter
        # out None so a missing eos id can never be compared against a
        # real token.
        self.stop_token_ids = {tid for tid in stop_token_ids if tid is not None}

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids is (batch, seq); only the newest token can be the stop
        # token, so one membership test on the last position suffices.
        return int(input_ids[0, -1]) in self.stop_token_ids


# Get stop token IDs
stop_token_ids = [
    tokenizer.eos_token_id,  # Standard EOS
]
# Add <|end|> token if it exists
# NOTE(review): many tokenizers return the *unk* token id (a non-negative
# int) for tokens missing from the vocab, so the `>= 0` guard below may
# accept a bogus id on other models — confirm "<|end|>" exists for Phi-3.
end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
if isinstance(end_token_id, int) and end_token_id >= 0:
    stop_token_ids.append(end_token_id)

# Wrapped in the list container transformers' generate() expects.
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])


# ======================
# Response Function
# ======================
def respond(message: str, history):
    """
    Generate a response from the Phi-3 model.

    Args:
        message (str): New user input.
        history (List[dict]): Chat history in {"role": ..., "content": ...}
            format (Gradio "messages" style).

    Returns:
        str: The model's response text (stripped); "" for blank input.
    """
    # Guard: ignore empty / whitespace-only submissions.
    if not message.strip():
        return ""

    # Build the full conversation: prior turns plus the new user message.
    # `history + [...]` copies rather than mutating Gradio's history list.
    messages = history + [{"role": "user", "content": message}]

    # Render with the model's chat template; add_generation_prompt appends
    # the assistant header so the model continues as the assistant.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize. NOTE(review): truncation=True cuts from the *end* by
    # default, which would drop the newest turn if the 128k budget is
    # exceeded — confirm which side should be truncated.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=128000
    ).to(model.device)

    # Inference only: no gradient tracking.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.1,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # silences missing-pad warning
            stopping_criteria=stopping_criteria,
        )

    # Decode only the newly generated tokens (everything after the prompt).
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Strip whitespace the chat template tends to leave around the
    # assistant turn so chat bubbles don't start with a blank line.
    return response.strip()  # Gradio will auto-append to chat history


# ======================
# Gradio Interface
# ======================
# Chat UI: Gradio wires message/history into `respond` and renders the
# returned string as the assistant turn.
demo = gr.ChatInterface(
    fn=respond,  # callback: (message, history) -> assistant reply text
    chatbot=gr.Chatbot(
        height=600,
        type="messages"  # Required for Gradio v5
    ),
    textbox=gr.Textbox(
        placeholder="Ask me anything about AI, science, coding, and more...",
        container=False,
        scale=7  # relative width next to sibling components
    ),
    title="🧠 Phi-3 Mini (128K Context) Chat",
    description="""
    A demo of Microsoft's **Phi-3-mini-128k-instruct** model — a powerful small LLM with support for ultra-long context.
    Try asking it to summarize long texts, explain complex topics, or write code.
    """,
    examples=[
        "Who are you?",
        "Explain quantum entanglement simply.",
        "Write a Python function to detect cycles in a linked list."
    ],
    # Note: retry_btn, undo_btn, clear_btn removed — not supported in v5
    # Toolbar appears automatically
)

# ======================
# Launch
# ======================
if __name__ == "__main__":
    # Start the Gradio server; blocks until interrupted.
    demo.launch()