chaymaemerhrioui committed on
Commit
5fc03fb
·
verified ·
1 Parent(s): 1f8e1c1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +89 -26
main.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  import os
7
 
8
  # Initialize FastAPI app
9
- app = FastAPI(title="AI Model API", description="API for Description Generation")
10
 
11
  # Global variables for models
12
  description_model = None
@@ -30,40 +30,98 @@ def authenticate_huggingface():
30
 
31
  def load_models():
32
  global description_model, description_tokenizer
33
-
34
  try:
35
  print("Loading models...")
36
  if not authenticate_huggingface():
37
  print("⚠️ Warning: Not authenticated with Hugging Face.")
38
 
39
- fine_tuned_model = "chaymaemerhrioui/Brain_Model_ACC_Trainer"
40
-
 
 
41
  print(f"Loading fine-tuned model: {fine_tuned_model}")
 
42
 
 
43
  try:
44
- # Load the fine-tuned model
45
- description_model = AutoModelForCausalLM.from_pretrained(
46
- fine_tuned_model,
47
- torch_dtype=torch.float16,
48
- device_map="auto",
49
- trust_remote_code=True
50
- )
51
-
52
- # Load the tokenizer associated with the fine-tuned model
53
- description_tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model)
54
-
55
- print("✅ Successfully loaded the fine-tuned model and its tokenizer!")
56
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  except Exception as e:
58
- print(f"Failed to load fine-tuned model: {e}")
59
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # Set up tokenizer padding
62
  if description_tokenizer.pad_token is None:
63
  description_tokenizer.pad_token = description_tokenizer.eos_token
64
 
65
  print("✅ Model loading completed successfully!")
66
-
67
  except Exception as e:
68
  print(f"❌ All loading methods failed: {e}")
69
  raise e
@@ -92,7 +150,7 @@ async def health_check():
92
  "description_tokenizer": description_tokenizer is not None
93
  }
94
  return {
95
- "status": "healthy" if all(model_status.values()) else "partial",
96
  "models": model_status
97
  }
98
 
@@ -100,9 +158,10 @@ async def health_check():
100
  async def generate_description(item: DescriptionItem):
101
  if description_model is None or description_tokenizer is None:
102
  raise HTTPException(
103
- status_code=503,
104
  detail="Description model not available"
105
  )
 
106
  try:
107
  # Tokenize the input prompt
108
  inputs = description_tokenizer(
@@ -112,9 +171,11 @@ async def generate_description(item: DescriptionItem):
112
  truncation=True,
113
  max_length=512
114
  )
 
115
  # Move inputs to model device
116
  if hasattr(description_model, 'device'):
117
  inputs = {k: v.to(description_model.device) for k, v in inputs.items()}
 
118
  # Generate response
119
  with torch.no_grad():
120
  outputs = description_model.generate(
@@ -126,16 +187,18 @@ async def generate_description(item: DescriptionItem):
126
  pad_token_id=description_tokenizer.eos_token_id,
127
  repetition_penalty=1.1
128
  )
 
129
  # Decode only the new tokens
130
  input_length = inputs['input_ids'].shape[1]
131
  description = description_tokenizer.decode(
132
- outputs[0][input_length:],
133
  skip_special_tokens=True
134
  )
135
-
136
  return {"description": description.strip()}
 
137
  except Exception as e:
138
  raise HTTPException(
139
- status_code=500,
140
  detail=f"Error generating description: {str(e)}"
141
- )
 
6
  import os
7
 
8
  # Initialize FastAPI app
9
+ app = FastAPI(title="AI Model API", description="API for Description and UML Generation")
10
 
11
  # Global variables for models
12
  description_model = None
 
30
 
31
  def load_models():
32
  global description_model, description_tokenizer
33
+
34
  try:
35
  print("Loading models...")
36
  if not authenticate_huggingface():
37
  print("⚠️ Warning: Not authenticated with Hugging Face.")
38
 
39
+ # Model configuration
40
+ fine_tuned_model = "chaymaemerhrioui/Brain_Model_ACC_unsloth"
41
+ base_model = "unsloth/mistral-7b-bnb-4bit"
42
+
43
  print(f"Loading fine-tuned model: {fine_tuned_model}")
44
+ print(f"Base model: {base_model}")
45
 
46
+ # Method 1: Try loading as PEFT/LoRA adapter
47
  try:
48
+ print("Attempting PEFT/LoRA loading...")
49
+ from peft import PeftModel, AutoPeftModelForCausalLM
50
+
51
+ # Option 1a: Use AutoPeftModelForCausalLM (handles everything automatically)
52
+ try:
53
+ print("Using AutoPeftModelForCausalLM...")
54
+ description_model = AutoPeftModelForCausalLM.from_pretrained(
55
+ fine_tuned_model,
56
+ torch_dtype=torch.float16,
57
+ device_map="auto",
58
+ trust_remote_code=True
59
+ )
60
+
61
+ # Get the base model tokenizer
62
+ base_model_name = description_model.peft_config.base_model_name_or_path
63
+ description_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
64
+ print("✅ Successfully loaded with AutoPeftModelForCausalLM!")
65
+
66
+ except Exception as e1:
67
+ print(f"AutoPeftModelForCausalLM failed: {e1}")
68
+
69
+ # Option 1b: Manual PEFT loading
70
+ print("Trying manual PEFT loading...")
71
+ print("Loading base model...")
72
+ description_tokenizer = AutoTokenizer.from_pretrained(base_model)
73
+ base_model_obj = AutoModelForCausalLM.from_pretrained(
74
+ base_model,
75
+ torch_dtype=torch.float16,
76
+ device_map="auto"
77
+ )
78
+
79
+ print("Loading PEFT adapter...")
80
+ description_model = PeftModel.from_pretrained(
81
+ base_model_obj,
82
+ fine_tuned_model
83
+ )
84
+ print("✅ Successfully loaded with manual PEFT!")
85
+
86
  except Exception as e:
87
+ print(f"PEFT loading failed: {e}")
88
+
89
+ # Method 2: Try loading as regular fine-tuned model with base model tokenizer
90
+ try:
91
+ print("Attempting regular fine-tuned model loading...")
92
+
93
+ # Use base model tokenizer (often works better for fine-tuned models)
94
+ print("Loading tokenizer from base model...")
95
+ description_tokenizer = AutoTokenizer.from_pretrained(base_model)
96
+
97
+ print("Loading fine-tuned model...")
98
+ description_model = AutoModelForCausalLM.from_pretrained(
99
+ fine_tuned_model,
100
+ torch_dtype=torch.float16,
101
+ device_map="auto",
102
+ trust_remote_code=True
103
+ )
104
+ print("✅ Successfully loaded as regular fine-tuned model!")
105
+
106
+ except Exception as e2:
107
+ print(f"Regular fine-tuned loading failed: {e2}")
108
+
109
+ # Method 3: Load base model only (as fallback)
110
+ print("Loading base model as fallback...")
111
+ description_tokenizer = AutoTokenizer.from_pretrained(base_model)
112
+ description_model = AutoModelForCausalLM.from_pretrained(
113
+ base_model,
114
+ torch_dtype=torch.float16,
115
+ device_map="auto"
116
+ )
117
+ print("⚠️ Loaded base model only - fine-tuning not applied!")
118
 
119
  # Set up tokenizer padding
120
  if description_tokenizer.pad_token is None:
121
  description_tokenizer.pad_token = description_tokenizer.eos_token
122
 
123
  print("✅ Model loading completed successfully!")
124
+
125
  except Exception as e:
126
  print(f"❌ All loading methods failed: {e}")
127
  raise e
 
150
  "description_tokenizer": description_tokenizer is not None
151
  }
152
  return {
153
+ "status": "healthy" if all(model_status.values()) else "partial",
154
  "models": model_status
155
  }
156
 
 
158
  async def generate_description(item: DescriptionItem):
159
  if description_model is None or description_tokenizer is None:
160
  raise HTTPException(
161
+ status_code=503,
162
  detail="Description model not available"
163
  )
164
+
165
  try:
166
  # Tokenize the input prompt
167
  inputs = description_tokenizer(
 
171
  truncation=True,
172
  max_length=512
173
  )
174
+
175
  # Move inputs to model device
176
  if hasattr(description_model, 'device'):
177
  inputs = {k: v.to(description_model.device) for k, v in inputs.items()}
178
+
179
  # Generate response
180
  with torch.no_grad():
181
  outputs = description_model.generate(
 
187
  pad_token_id=description_tokenizer.eos_token_id,
188
  repetition_penalty=1.1
189
  )
190
+
191
  # Decode only the new tokens
192
  input_length = inputs['input_ids'].shape[1]
193
  description = description_tokenizer.decode(
194
+ outputs[0][input_length:],
195
  skip_special_tokens=True
196
  )
197
+
198
  return {"description": description.strip()}
199
+
200
  except Exception as e:
201
  raise HTTPException(
202
+ status_code=500,
203
  detail=f"Error generating description: {str(e)}"
204
+ )