thava committed on
Commit
17cbe2a
·
1 Parent(s): 6ff22fa

Updates to fix errors

Browse files
Files changed (1) hide show
  1. app.py +95 -33
app.py CHANGED
@@ -1,14 +1,25 @@
1
  # app.py
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
3
  import torch
4
  import gradio as gr
5
 
6
- # Model ID
 
 
 
7
  MODEL_ID = "microsoft/Phi-3-mini-128k-instruct"
8
 
9
- print(f"Loading model: {MODEL_ID}")
10
 
11
- # Load tokenizer and model
 
 
 
 
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
13
 
14
  model = AutoModelForCausalLM.from_pretrained(
@@ -19,69 +30,120 @@ model = AutoModelForCausalLM.from_pretrained(
19
  attn_implementation="eager" # Use "flash_attention_2" if installed
20
  )
21
 
22
- print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
- def respond(message, history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  """
27
- message: string (new user input)
28
- history: list of dicts with 'role' and 'content' (Gradio v5 'messages' format)
 
 
 
 
29
  """
30
- # Build conversation history
31
- full_conversation = history + [{"role": "user", "content": message}]
 
 
 
32
 
33
  # Apply Phi-3 chat template
34
  prompt = tokenizer.apply_chat_template(
35
- full_conversation,
36
  tokenize=False,
37
  add_generation_prompt=True
38
  )
39
 
40
  # Tokenize
41
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128000).to(model.device)
 
 
 
 
 
 
42
 
43
  # Generate
44
  with torch.no_grad():
45
  outputs = model.generate(
46
  **inputs,
47
- max_new_tokens=1024,
48
  do_sample=True,
49
- temperature=0.7,
50
  top_p=0.9,
51
  eos_token_id=tokenizer.eos_token_id,
52
- pad_token_id=tokenizer.eos_token_id
 
53
  )
54
 
55
- # Decode only the new part
56
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
-
58
- # Extract assistant's response after the last <|assistant|>
59
- assistant_start = prompt.rfind("<|assistant|>") + len("<|assistant|>")
60
- response = full_response[assistant_start:].strip()
61
 
62
- # Return updated history
63
- return response, full_conversation + [{"role": "assistant", "content": response}]
64
 
65
 
66
- # Create Gradio ChatInterface
 
 
67
  demo = gr.ChatInterface(
68
  fn=respond,
69
  chatbot=gr.Chatbot(
70
  height=600,
71
- type="messages" # βœ… Required in Gradio 5
 
 
 
 
 
72
  ),
73
- textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
74
- title="🧠 Phi-3 Mini (128K) Chat",
75
- description="A demo of Microsoft's Phi-3-mini-128k-instruct model with long-context support.",
 
 
76
  examples=[
77
- "Summarize a long article.",
78
- "Explain how black holes work.",
79
- "Write a Python function to reverse a linked list."
80
  ],
81
- # ❌ Removed: retry_btn, undo_btn, clear_btn β€” not supported in v5
82
- # Use built-in toolbar instead
83
  )
84
 
 
85
  # Launch
 
86
  if __name__ == "__main__":
87
  demo.launch()
 
1
  # app.py
2
+ from transformers import (
3
+ AutoModelForCausalLM,
4
+ AutoTokenizer,
5
+ StoppingCriteria,
6
+ StoppingCriteriaList
7
+ )
8
  import torch
9
  import gradio as gr
10
 
11
+
12
+ # ======================
13
+ # Configuration
14
+ # ======================
15
  MODEL_ID = "microsoft/Phi-3-mini-128k-instruct"
16
 
 
17
 
18
+ # ======================
19
+ # Load Model & Tokenizer
20
+ # ======================
21
+ print(f"πŸš€ Loading model: {MODEL_ID}")
22
+
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
24
 
25
  model = AutoModelForCausalLM.from_pretrained(
 
30
  attn_implementation="eager" # Use "flash_attention_2" if installed
31
  )
32
 
33
+ print("βœ… Model loaded successfully!")
34
+
35
+
36
+ # ======================
37
+ # Stopping Criteria
38
+ # ======================
39
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the most recent token is a stop token.

    Only the last token of batch element 0 is inspected, which matches the
    single-conversation use in this app.
    """

    def __init__(self, stop_token_ids):
        # Copy into a plain list so the caller's collection can't mutate us.
        self.stop_token_ids = list(stop_token_ids)

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids: (batch, seq) tensor of generated ids so far.
        last_token = input_ids[0, -1]
        return any(last_token == stop_id for stop_id in self.stop_token_ids)
48
 
49
 
50
# ======================
# Stop-token setup
# ======================
# Collect every token ID that should end generation. Phi-3 closes an
# assistant turn with "<|end|>" in addition to the standard EOS token.
stop_token_ids = []

# eos_token_id can legitimately be None for some tokenizers — only add it
# when it exists (the original appended it unconditionally).
if tokenizer.eos_token_id is not None:
    stop_token_ids.append(tokenizer.eos_token_id)

# convert_tokens_to_ids maps an out-of-vocab token to the unk id (or None)
# instead of raising, so a plain ">= 0" check would happily register unk as
# a stop token. Filter that case out, and avoid duplicating the EOS id.
end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
if (
    isinstance(end_token_id, int)
    and end_token_id >= 0
    and end_token_id != tokenizer.unk_token_id
    and end_token_id not in stop_token_ids
):
    stop_token_ids.append(end_token_id)

stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])
60
+
61
+
62
+ # ======================
63
+ # Response Function
64
+ # ======================
65
def respond(message: str, history):
    """Generate a chat response from the Phi-3 model.

    Args:
        message: New user input.
        history: Prior turns as a list of {"role": ..., "content": ...}
            dicts (Gradio v5 "messages" format).

    Returns:
        str: The model's reply text only — Gradio appends it to the history.
    """
    # Ignore empty / whitespace-only submissions.
    if not message.strip():
        return ""

    # Full conversation = prior turns plus the new user message.
    messages = history + [{"role": "user", "content": message}]

    # Render the conversation with the Phi-3 chat template, ending with the
    # generation prompt so the model continues as the assistant.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize. NOTE(review): truncation=True keeps the *start* of the
    # prompt and drops the end — in a very long chat the newest turns would
    # be lost; consider left-side truncation if that ever matters.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=128000,
    ).to(model.device)

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.1,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
        )

    # Decode only the newly generated tokens (everything after the prompt).
    # (Debug prints of the tokenized conversation/response were removed:
    # they dumped the full chat contents to the server log on every turn.)
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Strip stray leading/trailing whitespace left over from the template.
    return response.strip()
 
115
 
116
 
117
+ # ======================
118
+ # Gradio Interface
119
+ # ======================
120
# ======================
# Gradio Interface
# ======================
# Chat UI wired to respond(). Fix: the description contained a mojibake
# em dash ("β€”") which now renders correctly as "—".
demo = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(
        height=600,
        type="messages",  # Required for Gradio v5
    ),
    textbox=gr.Textbox(
        placeholder="Ask me anything about AI, science, coding, and more...",
        container=False,
        scale=7,
    ),
    title="🧠 Phi-3 Mini (128K Context) Chat",
    description="""
A demo of Microsoft's **Phi-3-mini-128k-instruct** model — a powerful small LLM with support for ultra-long context.
Try asking it to summarize long texts, explain complex topics, or write code.
""",
    examples=[
        "Who are you?",
        "Explain quantum entanglement simply.",
        "Write a Python function to detect cycles in a linked list.",
    ],
    # Note: retry_btn, undo_btn, clear_btn removed — not supported in v5;
    # the toolbar appears automatically.
)
144
 
145
+ # ======================
146
  # Launch
147
+ # ======================
148
  if __name__ == "__main__":
149
  demo.launch()