richardprobe committed
Commit 03723d8 · verified · 1 Parent(s): 8e6d217

Update app.py

Files changed (1): app.py +60 -34
app.py CHANGED
@@ -5,18 +5,14 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from peft import PeftModel
 
 # --- CONFIG ---
-BASE_MODEL = "microsoft/Phi-4-mini-instruct" # base
+BASE_MODEL = "microsoft/Phi-4-mini-instruct"
 ADAPTER_REPO = "richardprobe/phi4-mini-chris-assistant-richard-adapter"
 SYSTEM_PROMPT = "You are Richard. Be concise and casual."
-
-# Use 4-bit quantization for smaller GPU Spaces
 LOAD_4BIT = True
 
-
 def load_model():
     print("Loading tokenizer...")
     tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
-
     print("Loading base model...")
     kwargs = dict(device_map="auto")
     if LOAD_4BIT:
@@ -33,65 +29,95 @@ def load_model():
     base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **kwargs)
 
     print("Loading adapter...")
+    # HF Hub auth if needed
     model = PeftModel.from_pretrained(base, ADAPTER_REPO, use_auth_token=os.getenv("HF_TOKEN"))
     model.eval()
-    return tok, model
 
+    # make sure pad token exists
+    if tok.pad_token_id is None:
+        tok.pad_token = tok.eos_token
 
-tok, model = load_model()
+    return tok, model
 
+tok, model = load_model()
 
-def chat_generate(history, temperature=0.7, top_p=0.95, max_new_tokens=256, repetition_penalty=1.1):
-    """
-    history: list[(user, assistant)] from gr.ChatInterface
-    Returns: assistant reply as a string
-    """
-    messages = []
+def _normalize_history(history):
+    """Accepts either tuples [(u,a), ...] or messages-style [{'role','content'}, ...]."""
+    msgs = []
     if SYSTEM_PROMPT:
-        messages.append({"role": "system", "content": SYSTEM_PROMPT})
-    for user, assistant in history:
-        if user:
-            messages.append({"role": "user", "content": user})
-        if assistant:
-            messages.append({"role": "assistant", "content": assistant})
+        msgs.append({"role": "system", "content": SYSTEM_PROMPT})
+
+    if not history:
+        return msgs
+
+    # messages-style
+    if isinstance(history[0], dict):
+        for m in history:
+            role = m.get("role")
+            content = m.get("content", "")
+            if isinstance(content, list):  # v5 can send [{"type":"text","text":"..."}]
+                content = "".join(
+                    c.get("text", "") if isinstance(c, dict) else str(c) for c in content
+                )
+            if role in {"user", "assistant", "system"}:
+                msgs.append({"role": role, "content": content})
+    else:
+        # tuples-style
+        for u, a in history:
+            if u:
+                msgs.append({"role": "user", "content": u})
+            if a:
+                msgs.append({"role": "assistant", "content": a})
+    return msgs
+
+def chat_generate(message, history, temperature=0.7, top_p=0.95, max_new_tokens=256, repetition_penalty=1.1):
+    # Build messages
+    messages = _normalize_history(history)
+    if message:
+        messages.append({"role": "user", "content": message})
 
     inputs = tok.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        return_tensors="pt"
+        messages, add_generation_prompt=True, return_tensors="pt"
     ).to(model.device)
 
     gen_kwargs = dict(
         max_new_tokens=int(max_new_tokens),
         temperature=float(temperature),
         top_p=float(top_p),
-        do_sample=(temperature > 0),
+        do_sample=float(temperature) > 0,
         repetition_penalty=float(repetition_penalty),
         eos_token_id=tok.eos_token_id,
-        pad_token_id=tok.eos_token_id,
+        pad_token_id=tok.pad_token_id,
     )
 
-    with torch.inference_mode(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=torch.bfloat16):
-        output = model.generate(inputs, **gen_kwargs)
+    with torch.inference_mode():
+        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available(), dtype=torch.bfloat16):
+            out = model.generate(inputs, **gen_kwargs)
 
-    gen_tokens = output[0][inputs.shape[-1]:]
+    gen_tokens = out[0][inputs.shape[-1]:]
     text = tok.decode(gen_tokens, skip_special_tokens=True, errors="ignore")
     return text.strip()
 
-
 demo = gr.ChatInterface(
     fn=chat_generate,
     title="Phi-4 Mini + LoRA Adapter (Chris style)",
     description="Base: microsoft/Phi-4-mini-instruct + your LoRA adapter. Style-tuned chat.",
+    additional_inputs=[
+        gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-p"),
+        gr.Slider(16, 512, value=256, step=16, label="Max new tokens"),
+        gr.Slider(1.0, 1.5, value=1.1, step=0.05, label="Repetition penalty"),
+    ],
+    # Each example is: [message, *additional_inputs]
     examples=[
        ["What are you up to?", 0.7, 0.95, 256, 1.1],
-        ["You coming?", 0.7, 0.95, 256, 1.1],
-        ["I'm on the can", 0.7, 0.95, 256, 1.1],
+        ["You coming?", 0.7, 0.95, 256, 1.1],
+        ["I'm on the can", 0.7, 0.95, 256, 1.1],
    ],
-    cache_examples=True # (optional)
-
+    cache_examples=False,  # turn off while debugging; turn on later if you want
 )
 
-
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
+    demo.queue(concurrency_count=1, max_size=8)
+    # Hide API docs to avoid the schema crash toast
+    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, show_error=True)
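
Note: the new _normalize_history helper accepts both history shapes that gr.ChatInterface can pass in: the classic list of (user, assistant) pairs and the messages-style list of dicts used by newer Gradio releases (including the v5 content-parts form). A minimal smoke test, assuming the functions above are defined in a local Python session (the inputs here are hypothetical):

    # Pair-style history, as older gr.ChatInterface sends it
    tuples_style = [["hey", "yo"], ["you coming?", None]]
    # Messages-style history; Gradio v5 may wrap content in parts
    messages_style = [
        {"role": "user", "content": [{"type": "text", "text": "hey"}]},
        {"role": "assistant", "content": "yo"},
    ]
    assert _normalize_history(tuples_style)[0]["role"] == "system"  # system prompt is prepended
    assert _normalize_history(messages_style)[-1] == {"role": "assistant", "content": "yo"}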
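
Note: a few of the calls in this revision are deprecated or removed in newer library releases, so the Space may need pinned versions or the current spellings. A sketch of those spellings, assuming recent huggingface_hub/PEFT, PyTorch >= 2.4, and Gradio 4.x (not part of this commit):

    # `use_auth_token` is deprecated on Hub loaders; recent PEFT accepts `token`
    model = PeftModel.from_pretrained(base, ADAPTER_REPO, token=os.getenv("HF_TOKEN"))

    # torch.cuda.amp.autocast is deprecated since PyTorch 2.4; use torch.amp.autocast
    with torch.inference_mode(), torch.amp.autocast("cuda", enabled=torch.cuda.is_available(), dtype=torch.bfloat16):
        out = model.generate(inputs, **gen_kwargs)

    # Gradio 4.x removed queue(concurrency_count=...); the replacement is default_concurrency_limit
    demo.queue(default_concurrency_limit=1, max_size=8)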