Suguru1846 committed on
Commit
3f8d4b3
·
verified ·
1 Parent(s): e70f7f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -70
app.py CHANGED
@@ -1,88 +1,73 @@
1
  import os
 
2
  import torch
3
  from fastapi import FastAPI
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
- from peft import PeftModel
6
 
7
- # Disable Triton & set proper environment variables for HF Spaces
 
 
 
8
  os.environ["TRITON_DISABLE"] = "1"
9
- os.environ["BNB_DISABLE_TRITON"] = "1"
10
  os.environ["USE_TORCH"] = "1"
11
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"
 
 
 
12
 
13
- # Set writable cache locations (HF Spaces needs explicit cache dirs)
14
- HF_CACHE_DIR = "/app/.cache/huggingface"
15
- TORCH_CACHE_DIR = "/app/.cache/torch"
16
-
17
- os.environ["HF_HOME"] = HF_CACHE_DIR
18
- os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
19
- os.environ["TORCH_HOME"] = TORCH_CACHE_DIR
20
-
21
- # Create necessary directories & fix permissions
22
- for cache_dir in [HF_CACHE_DIR, TORCH_CACHE_DIR, "/tmp"]:
23
- os.makedirs(cache_dir, exist_ok=True)
24
- os.chmod(cache_dir, 0o777) # Ensure all users can read/write
25
-
26
- # Initialize FastAPI
27
  app = FastAPI()
28
 
29
- # Load base model & tokenizer
30
  base_model_name = "unsloth/Llama-3.2-3B-Instruct"
31
- adapter_name = "Suguru1846/lora_model_counseling_4bit"
32
-
33
- print("🚀 Loading base model...")
34
- try:
35
- tokenizer = AutoTokenizer.from_pretrained(
36
- base_model_name, cache_dir=HF_CACHE_DIR, trust_remote_code=True
37
- )
38
-
39
- model = AutoModelForCausalLM.from_pretrained(
40
- base_model_name,
41
- torch_dtype=torch.float16,
42
- device_map="auto",
43
- cache_dir=HF_CACHE_DIR,
44
- trust_remote_code=True
45
- )
46
 
47
- print("✅ Base model loaded successfully")
 
 
 
48
 
49
- # Try loading LoRA adapter
50
- print("🔄 Attempting to load LoRA adapter...")
51
  try:
52
- model = PeftModel.from_pretrained(
53
- model,
54
- adapter_name,
55
- device_map="auto",
56
- cache_dir=HF_CACHE_DIR
57
- )
58
- print("✅ LoRA adapter loaded successfully")
59
- except Exception as adapter_error:
60
- print(f"⚠️ LoRA adapter loading failed: {adapter_error}")
61
- print("⚠️ Running with base model only")
62
-
63
- except Exception as model_error:
64
- print(f"❌ Error loading base model: {model_error}")
65
- print("🔄 Falling back to OPT-350M model...")
66
-
67
- # Fallback model in case of failure
68
- base_model_name = "facebook/opt-350m"
69
-
70
- tokenizer = AutoTokenizer.from_pretrained(
71
- base_model_name, cache_dir=HF_CACHE_DIR
72
- )
73
-
74
- model = AutoModelForCausalLM.from_pretrained(
75
- base_model_name,
76
- torch_dtype=torch.float16,
77
- device_map="auto",
78
- cache_dir=HF_CACHE_DIR
79
- )
80
 
81
- print("✅ Using fallback OPT-350M model")
 
82
 
83
  @app.post("/generate")
84
  async def generate_text(prompt: str, max_tokens: int = 50):
85
- """Generates text using the loaded model."""
86
  try:
87
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
88
  outputs = model.generate(
@@ -95,11 +80,10 @@ async def generate_text(prompt: str, max_tokens: int = 50):
95
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
96
  return {"response": response}
97
  except Exception as e:
98
- print(f"❌ Error generating text: {str(e)}")
99
  return {"error": str(e)}
100
 
101
  @app.get("/")
102
  async def root():
103
- """Health check endpoint."""
104
- model_type = "Base Llama-3.2 with adapter" if hasattr(model, "peft_config") else "Base Llama-3.2 only"
105
- return {"message": f"AI Model is Running! Using: {model_type}"}
 
1
  import os
2
+ import sys
3
  import torch
4
  from fastapi import FastAPI
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
6
 
7
+ # Create a directory with proper permissions
8
+ os.makedirs("/tmp/hf_cache", exist_ok=True)
9
+
10
+ # Set environment variables
11
  os.environ["TRITON_DISABLE"] = "1"
12
+ os.environ["BNB_DISABLE_TRITON"] = "1"
13
  os.environ["USE_TORCH"] = "1"
14
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"
15
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
16
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
17
+ os.environ["TORCH_HOME"] = "/tmp/hf_cache"
18
 
19
+ # FastAPI app
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  app = FastAPI()
21
 
22
+ # Load the base model
23
  base_model_name = "unsloth/Llama-3.2-3B-Instruct"
24
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name, cache_dir="/tmp/hf_cache")
25
+ base_model = AutoModelForCausalLM.from_pretrained(
26
+ base_model_name,
27
+ torch_dtype=torch.float16,
28
+ device_map="auto",
29
+ cache_dir="/tmp/hf_cache"
30
+ )
 
 
 
 
 
 
 
 
31
 
32
+ # Try different PEFT versions
33
+ adapter_name = "Suguru1846/lora_model_counseling_4bit"
34
+ model = base_model # Default to base model
35
+ adapter_loaded = False
36
 
37
+ peft_versions = ["0.3.0", "0.4.0", "0.5.0"]
38
+ for version in peft_versions:
39
  try:
40
+ # Add PEFT version to path
41
+ version_path = f"/app/peft_versions/{version}"
42
+ if version_path not in sys.path:
43
+ sys.path.insert(0, version_path)
44
+
45
+ # Import PEFT from this version
46
+ from peft import PeftModel
47
+
48
+ # Try loading with this version
49
+ print(f"Trying PEFT version {version}")
50
+ adapter_model = PeftModel.from_pretrained(base_model, adapter_name)
51
+ model = adapter_model
52
+ adapter_loaded = True
53
+ print(f"Success! Adapter loaded with PEFT version {version}")
54
+ break
55
+ except Exception as e:
56
+ print(f"Failed with PEFT version {version}: {str(e)}")
57
+ # Remove this version from path
58
+ if version_path in sys.path:
59
+ sys.path.remove(version_path)
60
+
61
+ # Reset imports to try next version
62
+ if "peft" in sys.modules:
63
+ del sys.modules["peft"]
 
 
 
 
64
 
65
+ if not adapter_loaded:
66
+ print("Could not load adapter with any PEFT version. Using base model only.")
67
 
68
  @app.post("/generate")
69
  async def generate_text(prompt: str, max_tokens: int = 50):
70
+ """Generates text using the Llama model."""
71
  try:
72
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
73
  outputs = model.generate(
 
80
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
81
  return {"response": response}
82
  except Exception as e:
83
+ print(f"Error generating text: {str(e)}")
84
  return {"error": str(e)}
85
 
86
  @app.get("/")
87
  async def root():
88
+ model_status = "with adapter" if adapter_loaded else "base model only"
89
+ return {"message": f"Model is running ({model_status})"}