Update app.py
app.py CHANGED
@@ -16,17 +16,27 @@ current_model_name = None
 # Load selected model
 def load_model(model_name):
     global tokenizer, model, current_model_name
+
+    # Only load if it's a different model
+    if current_model_name == model_name:
+        return
+
     full_model_name = f"MaxLSB/{model_name}"
+    print(f"Loading model: {full_model_name}")
     tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
     model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
     model.eval()
     current_model_name = model_name
+    print(f"Model loaded: {current_model_name}")
 
 # Initialize default model
 load_model("LeCarnet-8M")
 
 # Streaming generation function
-def respond(message, max_tokens, temperature, top_p):
+def respond(message, max_tokens, temperature, top_p, selected_model):
+    # Ensure the correct model is loaded before generation
+    load_model(selected_model)
+
     inputs = tokenizer(message, return_tensors="pt")
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
 
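Note: the hunk above cuts off inside respond(); the rest of the function sits outside the diff context. For orientation, a minimal sketch of the usual TextIteratorStreamer pattern such a function typically completes with — this is an assumption, not this Space's actual code, and the generate() kwargs in particular are guesses. It relies on the module-level tokenizer, model, and load_model shown above.

from threading import Thread

from transformers import TextIteratorStreamer

def respond(message, max_tokens, temperature, top_p, selected_model):
    load_model(selected_model)
    inputs = tokenizer(message, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,           # assumed sampling configuration
        temperature=temperature,
        top_p=top_p,
    )
    # generate() blocks, so it runs in a background thread while the
    # streamer yields decoded text chunks on this one.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text                # bot() re-yields this into the Chatbot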
@@ -57,18 +67,17 @@ def user(message, chat_history):
     chat_history.append([message, None])
     return "", chat_history
 
-# Bot response handler
-def bot(chatbot, max_tokens, temperature, top_p):
+# Bot response handler - UPDATED to pass selected model
+def bot(chatbot, max_tokens, temperature, top_p, selected_model):
     message = chatbot[-1][0]
-    response_generator = respond(message, max_tokens, temperature, top_p)
+    response_generator = respond(message, max_tokens, temperature, top_p, selected_model)
     for response in response_generator:
         chatbot[-1][1] = response
         yield chatbot
 
 # Model selector handler
 def update_model(model_name):
     load_model(model_name)
-    # Return the model_name directly instead of using gr.Dropdown.update()
     return model_name
 
 # Clear chat handler

@@ -84,7 +93,6 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
     </div>
     """)
 
-    # Create the msg_input early, but don't render it yet
     msg_input = gr.Textbox(
         placeholder="Il était une fois un petit garçon",
         label="User Input",

@@ -118,14 +126,13 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
         bubble_full_width=False,
         height=500
     )
-    # Now render the msg_input inside the right column, below the chatbot
     msg_input.render()
 
     # Event Handlers
     model_selector.change(
         fn=update_model,
         inputs=[model_selector],
         outputs=[model_selector],
     )
 
     msg_input.submit(

@@ -135,7 +142,7 @@ with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
         queue=False
     ).then(
         fn=bot,
-        inputs=[chatbot, max_tokens, temperature, top_p],
+        inputs=[chatbot, max_tokens, temperature, top_p, model_selector],  # Pass model_selector
         outputs=[chatbot]
     )
 
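The hunks also reference components declared outside the diff context: model_selector, chatbot, max_tokens, temperature, and top_p. A minimal sketch of how such a Blocks layout is plausibly declared — the names come from the diff, but the dropdown choices other than LeCarnet-8M and all slider ranges are assumptions.

import gradio as gr

with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    # Assumed layout; the real Space defines these above the changed hunks.
    model_selector = gr.Dropdown(
        choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],  # hypothetical list
        value="LeCarnet-8M",
        label="Model",
    )
    chatbot = gr.Chatbot(bubble_full_width=False, height=500)
    max_tokens = gr.Slider(1, 512, value=128, label="Max New Tokens")  # assumed range
    temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")  # assumed range
    top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top-p")              # assumed range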