Update app.py
app.py CHANGED
@@ -27,15 +27,35 @@ class ConversationManager:
             return self.models[model_name]
         except Exception as e:
             print(f"Failed to load model {model_name}: {e}")
+            print(f"Error type: {type(e).__name__}")
+            print(f"Error details: {str(e)}")
             return None
 
     def generate_response(self, model_name, prompt):
         model, tokenizer = self.load_model(model_name)
-
+
+        # Format the prompt based on the model
+        if "llama" in model_name.lower():
+            formatted_prompt = self.format_llama2_prompt(prompt)
+        else:
+            formatted_prompt = self.format_general_prompt(prompt)
+
+        inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=1024, truncation=True)
         with torch.no_grad():
             outputs = model.generate(**inputs, max_length=200, num_return_sequences=1, do_sample=True)
         return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
+    def format_llama2_prompt(self, prompt):
+        B_INST, E_INST = "[INST]", "[/INST]"
+        B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+        system_prompt = "You are a helpful AI assistant. Please provide a concise and relevant response."
+
+        formatted_prompt = f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{prompt.strip()} {E_INST}"
+        return formatted_prompt
+
+    def format_general_prompt(self, prompt):
+        # A general format that might work for other models
+        return f"Human: {prompt.strip()}\n\nAssistant:"
+
     def add_to_conversation(self, model_name, response):
         self.conversation.append((model_name, response))
         if "task complete?" in response.lower():  # Check for task completion marker
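The template built by the new format_llama2_prompt follows the Llama-2 chat convention ([INST] / <<SYS>> markers). As a standalone sketch (not part of app.py), the strings copied from the diff above produce output like this for a sample prompt:

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
system_prompt = "You are a helpful AI assistant. Please provide a concise and relevant response."

prompt = "Summarize the conversation so far."  # sample prompt, illustrative only
print(f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{prompt.strip()} {E_INST}")
# [INST] <<SYS>>
# You are a helpful AI assistant. Please provide a concise and relevant response.
# <</SYS>>
#
# Summarize the conversation so far. [/INST]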
@@ -72,48 +92,56 @@ def get_model(dropdown, custom):
     return (model, model)  # Return a tuple (label, value)
 
 def chat(model1, model2, user_input, history, inserted_response=""):
-    if not manager.conversation:
-        manager.initial_prompt = user_input
-        manager.clear_conversation()
-        manager.add_to_conversation("User", user_input)
-
-    models = [model1, model2]
-    current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1
-
-    while not manager.task_complete:  # Continue until task is complete
-        if manager.is_paused:
-            yield history, "Conversation paused."
-            return
-
-        model = models[current_model_index]
-        manager.current_model = model
-
-        if inserted_response and current_model_index == 0:
-            response = inserted_response
-            inserted_response = ""
-        else:
-            prompt = manager.get_conversation_history() + "\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
-            response = manager.generate_response(model, prompt)
-
-        manager.
+    try:
+        model1 = get_model(model1, model1_custom.value)[0]
+        model2 = get_model(model2, model2_custom.value)[0]
+
+        if not manager.load_model(model1) or not manager.load_model(model2):
+            return "Error: Failed to load one or both models. Please check the model names and try again.", ""
+
+        if not manager.conversation:
+            manager.initial_prompt = user_input
+            manager.clear_conversation()
+            manager.add_to_conversation("User", user_input)
+
+        models = [model1, model2]
+        current_model_index = 0 if manager.current_model in ["User", "Model 2"] else 1
+
+        while not manager.task_complete:
+            if manager.is_paused:
+                yield history, "Conversation paused."
+                return
+
+            model = models[current_model_index]
+            manager.current_model = model
+
+            if inserted_response and current_model_index == 0:
+                response = inserted_response
+                inserted_response = ""
+            else:
+                conversation_history = manager.get_conversation_history()
+                prompt = f"{conversation_history}\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
+                response = manager.generate_response(model, prompt)
+
+            manager.add_to_conversation(model, response)
+            history = manager.get_conversation_history()
+
+            for i in range(manager.delay, 0, -1):
+                yield history, f"{model} is writing... {i}"
+                time.sleep(1)
+
+            yield history, ""
+
+            if manager.task_complete:
+                yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
+                return
+
+            current_model_index = (current_model_index + 1) % 2
+
+        return history, "Conversation completed."
+    except Exception as e:
+        print(f"Error in chat function: {str(e)}")
+        return f"An error occurred: {str(e)}", ""
 
 def user_satisfaction(satisfied, history):
     if satisfied.lower() == 'yes':
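The reworked chat() streams progress by yielding (history, status) pairs and counting down manager.delay seconds between turns. A minimal standalone sketch of that countdown pattern (not part of app.py; the delay value here is assumed):

import time

def countdown_status(model_name, delay=3):
    # Yield a status string once per second, then clear it, mirroring the
    # "is writing..." loop in the new chat() body.
    for i in range(delay, 0, -1):
        yield f"{model_name} is writing... {i}"
        time.sleep(1)
    yield ""

for status in countdown_status("Model 1"):
    print(status or "(done)")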