Keeby-smilyai commited on
Commit
033fdf6
·
verified ·
1 Parent(s): de9422e

Update backend.py

Browse files
Files changed (1) hide show
  1. backend.py +4 -3
backend.py CHANGED
@@ -1,4 +1,4 @@
1
- # backend.py — FINAL HARDENED VERSION
2
  import sqlite3
3
  import os
4
  import json
@@ -74,7 +74,8 @@ def load_model(model_name):
74
  if model_name in _MODEL_CACHE: return _MODEL_CACHE[model_name]
75
  print(f"Loading model: {model_name}...")
76
  tokenizer = AutoTokenizer.from_pretrained(model_name)
77
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True, attn_implementation="eager")
 
78
  _MODEL_CACHE[model_name] = (tokenizer, model)
79
  print(f"Model {model_name} loaded and cached.")
80
  return tokenizer, model
@@ -140,7 +141,7 @@ def generate_with_model(role: str, prompt: str) -> str:
140
  print(f"Error during model generation for role {role}: {e}")
141
  return f'{{"error": "Failed to generate response: {str(e)}"}}'
142
 
143
- # ------------------------------ THE AGENT CHAIN EXECUTOR ------------------------------
144
  def run_agent_chain(project_id, user_id, initial_prompt):
145
  project_dir = get_project_dir(user_id, project_id)
146
  log_entries = []
 
1
+ # backend.py — FINAL HARDENED VERSION v1.1
2
  import sqlite3
3
  import os
4
  import json
 
74
  if model_name in _MODEL_CACHE: return _MODEL_CACHE[model_name]
75
  print(f"Loading model: {model_name}...")
76
  tokenizer = AutoTokenizer.from_pretrained(model_name)
77
+ # --- THIS IS THE FIX FOR THE WARNING ---
78
+ model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto", device_map="auto", trust_remote_code=True, attn_implementation="eager")
79
  _MODEL_CACHE[model_name] = (tokenizer, model)
80
  print(f"Model {model_name} loaded and cached.")
81
  return tokenizer, model
 
141
  print(f"Error during model generation for role {role}: {e}")
142
  return f'{{"error": "Failed to generate response: {str(e)}"}}'
143
 
144
+ # ------------------------------ THE AGENT CHAIN EXECUTOR (REWRITTEN FOR RELIABILITY) ------------------------------
145
  def run_agent_chain(project_id, user_id, initial_prompt):
146
  project_dir = get_project_dir(user_id, project_id)
147
  log_entries = []