Suguru1846 commited on
Commit
1f4f76d
Β·
verified Β·
1 Parent(s): 1343cce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -41
app.py CHANGED
@@ -2,75 +2,87 @@ import os
2
  import torch
3
  from fastapi import FastAPI
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
- from huggingface_hub import hf_hub_download
6
 
7
- # Set these environment variables before importing any libraries
8
  os.environ["TRITON_DISABLE"] = "1"
9
- os.environ["BNB_DISABLE_TRITON"] = "1"
10
  os.environ["USE_TORCH"] = "1"
11
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"
12
 
13
- # Set writable cache locations
14
- os.environ["HF_HOME"] = "/app/.cache/huggingface"
15
- os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"
16
- os.environ["TORCH_HOME"] = "/app/.cache/torch"
17
 
18
- # FastAPI app instance
 
 
 
 
 
 
 
 
 
19
  app = FastAPI()
20
 
21
- # Load the base model and tokenizer
22
  base_model_name = "unsloth/Llama-3.2-3B-Instruct"
23
- tokenizer = AutoTokenizer.from_pretrained(
24
- base_model_name,
25
- cache_dir="/app/.cache/huggingface"
26
- )
27
 
 
28
  try:
29
- # Load the base model first
 
 
 
30
  model = AutoModelForCausalLM.from_pretrained(
31
  base_model_name,
32
  torch_dtype=torch.float16,
33
  device_map="auto",
34
- cache_dir="/app/.cache/huggingface"
 
35
  )
36
-
37
- # Now try loading and merging your adapter
 
 
 
38
  try:
39
- # Import PEFT after model loading to avoid conflicts
40
- from peft import PeftModel, PeftConfig
41
-
42
- adapter_name = "Suguru1846/lora_model_counseling_4bit"
43
-
44
- # First try the standard approach
45
  model = PeftModel.from_pretrained(
46
  model,
47
  adapter_name,
48
- device_map="auto"
 
49
  )
50
- print("Successfully loaded adapter with standard method")
51
  except Exception as adapter_error:
52
- print(f"Standard adapter loading failed: {str(adapter_error)}")
53
- print("Model is still running with base model only")
54
-
55
  except Exception as model_error:
56
- print(f"Error loading base model: {str(model_error)}")
57
- # Fallback to a working model
58
- model_name = "facebook/opt-350m"
 
 
 
 
 
 
 
59
  model = AutoModelForCausalLM.from_pretrained(
60
- model_name,
61
  torch_dtype=torch.float16,
62
  device_map="auto",
63
- cache_dir="/app/.cache/huggingface"
64
- )
65
- tokenizer = AutoTokenizer.from_pretrained(
66
- model_name,
67
- cache_dir="/app/.cache/huggingface"
68
  )
69
- print("Fell back to OPT-350M model")
 
70
 
71
  @app.post("/generate")
72
  async def generate_text(prompt: str, max_tokens: int = 50):
73
- """Generates text using the model."""
74
  try:
75
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
76
  outputs = model.generate(
@@ -83,10 +95,11 @@ async def generate_text(prompt: str, max_tokens: int = 50):
83
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
84
  return {"response": response}
85
  except Exception as e:
86
- print(f"Error generating text: {str(e)}")
87
  return {"error": str(e)}
88
 
89
  @app.get("/")
90
  async def root():
 
91
  model_type = "Base Llama-3.2 with adapter" if hasattr(model, "peft_config") else "Base Llama-3.2 only"
92
- return {"message": f"AI Model is Running! Using: {model_type}"}
 
2
  import torch
3
  from fastapi import FastAPI
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from peft import PeftModel
6
 
7
+ # Disable Triton & set proper environment variables for HF Spaces
8
  os.environ["TRITON_DISABLE"] = "1"
9
+ os.environ["BNB_DISABLE_TRITON"] = "1"
10
  os.environ["USE_TORCH"] = "1"
11
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"
12
 
13
+ # Set writable cache locations (HF Spaces needs explicit cache dirs)
14
+ HF_CACHE_DIR = "/app/.cache/huggingface"
15
+ TORCH_CACHE_DIR = "/app/.cache/torch"
 
16
 
17
+ os.environ["HF_HOME"] = HF_CACHE_DIR
18
+ os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
19
+ os.environ["TORCH_HOME"] = TORCH_CACHE_DIR
20
+
21
+ # Create necessary directories & fix permissions
22
+ for cache_dir in [HF_CACHE_DIR, TORCH_CACHE_DIR, "/tmp"]:
23
+ os.makedirs(cache_dir, exist_ok=True)
24
+ os.chmod(cache_dir, 0o777) # Ensure all users can read/write
25
+
26
+ # Initialize FastAPI
27
  app = FastAPI()
28
 
29
+ # Load base model & tokenizer
30
  base_model_name = "unsloth/Llama-3.2-3B-Instruct"
31
+ adapter_name = "Suguru1846/lora_model_counseling_4bit"
 
 
 
32
 
33
+ print("πŸš€ Loading base model...")
34
  try:
35
+ tokenizer = AutoTokenizer.from_pretrained(
36
+ base_model_name, cache_dir=HF_CACHE_DIR, trust_remote_code=True
37
+ )
38
+
39
  model = AutoModelForCausalLM.from_pretrained(
40
  base_model_name,
41
  torch_dtype=torch.float16,
42
  device_map="auto",
43
+ cache_dir=HF_CACHE_DIR,
44
+ trust_remote_code=True
45
  )
46
+
47
+ print("βœ… Base model loaded successfully")
48
+
49
+ # Try loading LoRA adapter
50
+ print("πŸ”„ Attempting to load LoRA adapter...")
51
  try:
 
 
 
 
 
 
52
  model = PeftModel.from_pretrained(
53
  model,
54
  adapter_name,
55
+ device_map="auto",
56
+ cache_dir=HF_CACHE_DIR
57
  )
58
+ print("βœ… LoRA adapter loaded successfully")
59
  except Exception as adapter_error:
60
+ print(f"⚠️ LoRA adapter loading failed: {adapter_error}")
61
+ print("⚠️ Running with base model only")
62
+
63
  except Exception as model_error:
64
+ print(f"❌ Error loading base model: {model_error}")
65
+ print("πŸ”„ Falling back to OPT-350M model...")
66
+
67
+ # Fallback model in case of failure
68
+ base_model_name = "facebook/opt-350m"
69
+
70
+ tokenizer = AutoTokenizer.from_pretrained(
71
+ base_model_name, cache_dir=HF_CACHE_DIR
72
+ )
73
+
74
  model = AutoModelForCausalLM.from_pretrained(
75
+ base_model_name,
76
  torch_dtype=torch.float16,
77
  device_map="auto",
78
+ cache_dir=HF_CACHE_DIR
 
 
 
 
79
  )
80
+
81
+ print("βœ… Using fallback OPT-350M model")
82
 
83
  @app.post("/generate")
84
  async def generate_text(prompt: str, max_tokens: int = 50):
85
+ """Generates text using the loaded model."""
86
  try:
87
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
88
  outputs = model.generate(
 
95
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
96
  return {"response": response}
97
  except Exception as e:
98
+ print(f"❌ Error generating text: {str(e)}")
99
  return {"error": str(e)}
100
 
101
  @app.get("/")
102
  async def root():
103
+ """Health check endpoint."""
104
  model_type = "Base Llama-3.2 with adapter" if hasattr(model, "peft_config") else "Base Llama-3.2 only"
105
+ return {"message": f"AI Model is Running! Using: {model_type}"}