Spaces:
Sleeping
Sleeping
Update backend.py
Browse files- backend.py +4 -3
backend.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# backend.py — FINAL HARDENED VERSION
|
| 2 |
import sqlite3
|
| 3 |
import os
|
| 4 |
import json
|
|
@@ -74,7 +74,8 @@ def load_model(model_name):
|
|
| 74 |
if model_name in _MODEL_CACHE: return _MODEL_CACHE[model_name]
|
| 75 |
print(f"Loading model: {model_name}...")
|
| 76 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 77 |
-
|
|
|
|
| 78 |
_MODEL_CACHE[model_name] = (tokenizer, model)
|
| 79 |
print(f"Model {model_name} loaded and cached.")
|
| 80 |
return tokenizer, model
|
|
@@ -140,7 +141,7 @@ def generate_with_model(role: str, prompt: str) -> str:
|
|
| 140 |
print(f"Error during model generation for role {role}: {e}")
|
| 141 |
return f'{{"error": "Failed to generate response: {str(e)}"}}'
|
| 142 |
|
| 143 |
-
# ------------------------------ THE AGENT CHAIN EXECUTOR ------------------------------
|
| 144 |
def run_agent_chain(project_id, user_id, initial_prompt):
|
| 145 |
project_dir = get_project_dir(user_id, project_id)
|
| 146 |
log_entries = []
|
|
|
|
| 1 |
+
# backend.py — FINAL HARDENED VERSION v1.1
|
| 2 |
import sqlite3
|
| 3 |
import os
|
| 4 |
import json
|
|
|
|
| 74 |
if model_name in _MODEL_CACHE: return _MODEL_CACHE[model_name]
|
| 75 |
print(f"Loading model: {model_name}...")
|
| 76 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 77 |
+
# --- THIS IS THE FIX FOR THE WARNING ---
|
| 78 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto", device_map="auto", trust_remote_code=True, attn_implementation="eager")
|
| 79 |
_MODEL_CACHE[model_name] = (tokenizer, model)
|
| 80 |
print(f"Model {model_name} loaded and cached.")
|
| 81 |
return tokenizer, model
|
|
|
|
| 141 |
print(f"Error during model generation for role {role}: {e}")
|
| 142 |
return f'{{"error": "Failed to generate response: {str(e)}"}}'
|
| 143 |
|
| 144 |
+
# ------------------------------ THE AGENT CHAIN EXECUTOR (REWRITTEN FOR RELIABILITY) ------------------------------
|
| 145 |
def run_agent_chain(project_id, user_id, initial_prompt):
|
| 146 |
project_dir = get_project_dir(user_id, project_id)
|
| 147 |
log_entries = []
|