soxogvv commited on
Commit
605e76c
·
verified ·
1 Parent(s): 02e2ce3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -33
app.py CHANGED
@@ -41,30 +41,56 @@ class CodeLlamaService:
41
  # Use the smallest Code Llama model that fits in 16GB
42
  model_name = "codellama/CodeLlama-7b-Instruct-hf"
43
 
44
- # Load with memory optimization
 
 
 
 
45
  self.tokenizer = AutoTokenizer.from_pretrained(
46
  model_name,
47
  use_fast=True,
48
  trust_remote_code=True
49
  )
50
 
51
- # Load model with optimizations for CPU inference
52
- self.model = AutoModelForCausalLM.from_pretrained(
53
- model_name,
54
- torch_dtype=torch.float16,
55
- low_cpu_mem_usage=True,
56
- trust_remote_code=True,
57
- device_map="auto"
58
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Create pipeline
61
- self.pipeline = pipeline(
62
- "text-generation",
63
- model=self.model,
64
- tokenizer=self.tokenizer,
65
- torch_dtype=torch.float16,
66
- device_map="auto"
67
- )
 
 
 
 
 
 
 
 
68
 
69
  self.is_loaded = True
70
  logger.info("Model loaded successfully!")
@@ -72,6 +98,16 @@ class CodeLlamaService:
72
  except Exception as e:
73
  logger.error(f"Error loading model: {str(e)}")
74
  self.is_loaded = False
 
 
 
 
 
 
 
 
 
 
75
  finally:
76
  self.is_loading = False
77
 
@@ -84,21 +120,32 @@ class CodeLlamaService:
84
  # Format prompt for instruction following
85
  formatted_prompt = f"<s>[INST] {prompt} [/INST]"
86
 
87
- # Generate response
88
- outputs = self.pipeline(
89
- formatted_prompt,
90
- max_length=max_length,
91
- temperature=temperature,
92
- do_sample=True,
93
- top_p=0.9,
94
- repetition_penalty=1.1,
95
- pad_token_id=self.tokenizer.eos_token_id,
96
- eos_token_id=self.tokenizer.eos_token_id
97
- )
 
 
 
 
98
 
99
  # Extract generated text
100
- generated_text = outputs[0]['generated_text']
101
- response = generated_text[len(formatted_prompt):].strip()
 
 
 
 
 
 
 
102
 
103
  # Split response into code and explanation if possible
104
  code, explanation = self._parse_response(response)
@@ -142,12 +189,27 @@ class CodeLlamaService:
142
  code_lines = []
143
  explanation_lines = []
144
 
 
145
  for line in lines:
146
- if (line.strip().startswith(('def ', 'class ', 'import ', 'from ', 'if ', 'for ', 'while ', ' ', '\t')) or
147
- '=' in line and not line.strip().startswith('#')):
 
 
148
  code_lines.append(line)
 
 
 
149
  else:
150
- explanation_lines.append(line)
 
 
 
 
 
 
 
 
 
151
 
152
  code = '\n'.join(code_lines)
153
  explanation = '\n'.join(explanation_lines)
 
41
  # Use the smallest Code Llama model that fits in 16GB
42
  model_name = "codellama/CodeLlama-7b-Instruct-hf"
43
 
44
+ # Check if CUDA is available
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+ logger.info(f"Using device: {device}")
47
+
48
+ # Load tokenizer
49
  self.tokenizer = AutoTokenizer.from_pretrained(
50
  model_name,
51
  use_fast=True,
52
  trust_remote_code=True
53
  )
54
 
55
+ # Configure model loading based on device
56
+ if device == "cuda":
57
+ # GPU: Use float16 for memory efficiency
58
+ self.model = AutoModelForCausalLM.from_pretrained(
59
+ model_name,
60
+ torch_dtype=torch.float16,
61
+ low_cpu_mem_usage=True,
62
+ trust_remote_code=True,
63
+ device_map="auto"
64
+ )
65
+ torch_dtype = torch.float16
66
+ else:
67
+ # CPU: Use float32 to avoid Half precision errors
68
+ self.model = AutoModelForCausalLM.from_pretrained(
69
+ model_name,
70
+ torch_dtype=torch.float32,
71
+ low_cpu_mem_usage=True,
72
+ trust_remote_code=True
73
+ )
74
+ # Move model to CPU explicitly
75
+ self.model = self.model.to('cpu')
76
+ torch_dtype = torch.float32
77
 
78
+ # Create pipeline with appropriate settings
79
+ if device == "cuda":
80
+ self.pipeline = pipeline(
81
+ "text-generation",
82
+ model=self.model,
83
+ tokenizer=self.tokenizer,
84
+ torch_dtype=torch_dtype,
85
+ device=0 # GPU device
86
+ )
87
+ else:
88
+ self.pipeline = pipeline(
89
+ "text-generation",
90
+ model=self.model,
91
+ tokenizer=self.tokenizer,
92
+ device=-1 # CPU device
93
+ )
94
 
95
  self.is_loaded = True
96
  logger.info("Model loaded successfully!")
 
98
  except Exception as e:
99
  logger.error(f"Error loading model: {str(e)}")
100
  self.is_loaded = False
101
+ # Clean up on failure
102
+ if hasattr(self, 'model') and self.model is not None:
103
+ del self.model
104
+ if hasattr(self, 'tokenizer') and self.tokenizer is not None:
105
+ del self.tokenizer
106
+ if hasattr(self, 'pipeline') and self.pipeline is not None:
107
+ del self.pipeline
108
+ gc.collect()
109
+ if torch.cuda.is_available():
110
+ torch.cuda.empty_cache()
111
  finally:
112
  self.is_loading = False
113
 
 
120
  # Format prompt for instruction following
121
  formatted_prompt = f"<s>[INST] {prompt} [/INST]"
122
 
123
+ # Generate response with error handling
124
+ generation_kwargs = {
125
+ "max_new_tokens": max_length,
126
+ "do_sample": True if temperature > 0 else False,
127
+ "temperature": temperature if temperature > 0 else None,
128
+ "top_p": 0.9 if temperature > 0 else None,
129
+ "repetition_penalty": 1.1,
130
+ "return_full_text": False,
131
+ "pad_token_id": self.tokenizer.eos_token_id
132
+ }
133
+
134
+ # Remove None values to avoid warnings
135
+ generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
136
+
137
+ outputs = self.pipeline(formatted_prompt, **generation_kwargs)
138
 
139
  # Extract generated text
140
+ if isinstance(outputs, list) and len(outputs) > 0:
141
+ if 'generated_text' in outputs[0]:
142
+ response = outputs[0]['generated_text']
143
+ else:
144
+ response = str(outputs[0])
145
+ else:
146
+ response = str(outputs)
147
+
148
+ response = response.strip()
149
 
150
  # Split response into code and explanation if possible
151
  code, explanation = self._parse_response(response)
 
189
  code_lines = []
190
  explanation_lines = []
191
 
192
+ in_code_block = False
193
  for line in lines:
194
+ # Simple heuristic to detect code vs explanation
195
+ if (line.strip().startswith(('def ', 'class ', 'import ', 'from ', 'if ', 'for ', 'while ', 'function', 'var ', 'let ', 'const ')) or
196
+ line.startswith((' ', '\t')) or
197
+ ('=' in line and not line.strip().startswith('#') and not line.strip().startswith('//'))):
198
  code_lines.append(line)
199
+ in_code_block = True
200
+ elif in_code_block and line.strip() == '':
201
+ code_lines.append(line) # Keep empty lines in code blocks
202
  else:
203
+ if in_code_block and line.strip():
204
+ # Check if this line looks like code or explanation
205
+ if any(char in line for char in ['{', '}', ';', '()', '[]']) and not line.strip().endswith('.'):
206
+ code_lines.append(line)
207
+ else:
208
+ explanation_lines.append(line)
209
+ in_code_block = False
210
+ else:
211
+ explanation_lines.append(line)
212
+ in_code_block = False
213
 
214
  code = '\n'.join(code_lines)
215
  explanation = '\n'.join(explanation_lines)