heffnt committed on
Commit
78f4f84
·
1 Parent(s): 109a454

Refactor model selection in app.py; replace single model variables with lists for local and API models, and update the model selection mechanism in the chatbot interface.

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -3,10 +3,17 @@ from huggingface_hub import InferenceClient
3
  import os
4
 
5
  # Configuration
6
- LOCAL_MODEL = "jakeboggs/MTG-Llama" # "microsoft/Phi-3-mini-4k-instruct"
7
- API_MODEL = "openai/gpt-oss-20b"
8
  DEFAULT_SYSTEM_MESSAGE = "You are an expert assistant for Magic: The Gathering. You're name is Smart Confidant but people tend to call you Bob."
9
 
 
 
 
 
 
 
 
10
  pipe = None
11
  stop_inference = False
12
 
@@ -58,7 +65,7 @@ def respond(
58
  temperature,
59
  top_p,
60
  hf_token: gr.OAuthToken,
61
- use_local_model: bool,
62
  ):
63
  global pipe
64
 
@@ -67,14 +74,18 @@ def respond(
67
  messages.extend(history)
68
  messages.append({"role": "user", "content": message})
69
 
 
 
 
 
70
  response = ""
71
 
72
- if use_local_model:
73
- print("[MODE] local")
74
  from transformers import pipeline
75
  import torch
76
- if pipe is None:
77
- pipe = pipeline("text-generation", model=LOCAL_MODEL)
78
 
79
  # Build prompt as plain text
80
  prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
@@ -91,13 +102,13 @@ def respond(
91
  yield response.strip()
92
 
93
  else:
94
- print("[MODE] api")
95
 
96
  if hf_token is None or not getattr(hf_token, "token", None):
97
  yield "⚠️ Please log in with your Hugging Face account first."
98
  return
99
 
100
- client = InferenceClient(token=hf_token.token, model=API_MODEL)
101
 
102
  for chunk in client.chat_completion(
103
  messages,
@@ -121,7 +132,7 @@ chatbot = gr.ChatInterface(
121
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
122
  gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
123
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
124
- gr.Checkbox(label=f"Use Local Model (API model is {API_MODEL} and local model is {LOCAL_MODEL})", value=False),
125
  ],
126
  type="messages",
127
  )
 
3
  import os
4
 
5
  # Configuration
6
+ LOCAL_MODELS = ["jakeboggs/MTG-Llama", "microsoft/Phi-3-mini-4k-instruct"]
7
+ API_MODELS = ["openai/gpt-oss-20b", "meta-llama/Meta-Llama-3-8B-Instruct"]
8
  DEFAULT_SYSTEM_MESSAGE = "You are an expert assistant for Magic: The Gathering. You're name is Smart Confidant but people tend to call you Bob."
9
 
10
+ # Create model options with labels
11
+ MODEL_OPTIONS = []
12
+ for model in LOCAL_MODELS:
13
+ MODEL_OPTIONS.append(f"{model} (local)")
14
+ for model in API_MODELS:
15
+ MODEL_OPTIONS.append(f"{model} (api)")
16
+
17
  pipe = None
18
  stop_inference = False
19
 
 
65
  temperature,
66
  top_p,
67
  hf_token: gr.OAuthToken,
68
+ selected_model: str,
69
  ):
70
  global pipe
71
 
 
74
  messages.extend(history)
75
  messages.append({"role": "user", "content": message})
76
 
77
+ # Determine if model is local or API and extract model name
78
+ is_local = selected_model.endswith("(local)")
79
+ model_name = selected_model.replace(" (local)", "").replace(" (api)", "")
80
+
81
  response = ""
82
 
83
+ if is_local:
84
+ print(f"[MODE] local - {model_name}")
85
  from transformers import pipeline
86
  import torch
87
+ if pipe is None or pipe.model.name_or_path != model_name:
88
+ pipe = pipeline("text-generation", model=model_name)
89
 
90
  # Build prompt as plain text
91
  prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
 
102
  yield response.strip()
103
 
104
  else:
105
+ print(f"[MODE] api - {model_name}")
106
 
107
  if hf_token is None or not getattr(hf_token, "token", None):
108
  yield "⚠️ Please log in with your Hugging Face account first."
109
  return
110
 
111
+ client = InferenceClient(token=hf_token.token, model=model_name)
112
 
113
  for chunk in client.chat_completion(
114
  messages,
 
132
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
133
  gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
134
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
135
+ gr.Radio(choices=MODEL_OPTIONS, label="Select Model", value=MODEL_OPTIONS[0]),
136
  ],
137
  type="messages",
138
  )