hadokenvskikoken committed on
Commit
82904d0
·
verified ·
1 Parent(s): 3e4e9a0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +51 -28
main.py CHANGED
@@ -6,17 +6,33 @@ import autopep8
6
  import subprocess
7
  import time
8
  import re
 
 
9
 
10
  app = FastAPI(title="Code Evaluation & Optimization API")
11
 
 
 
 
 
 
 
12
  # --- Load AI Model ---
13
  MODEL_NAME = "codellama/CodeLlama-7b-hf"
14
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
- model = AutoModelForCausalLM.from_pretrained(
16
- MODEL_NAME,
17
- device_map="auto",
18
- torch_dtype=torch.float16
19
- )
 
 
 
 
 
 
 
 
20
 
21
  # --- Request Models ---
22
  class CodeRequest(BaseModel):
@@ -105,27 +121,30 @@ def evaluate_code(user_code: str, lang: str) -> dict:
105
 
106
  def optimize_code_ai(user_code: str, lang: str) -> str:
107
  """Generate optimized code using AI and formatting"""
108
- # Basic formatting first
109
- if lang == "python":
110
- user_code = autopep8.fix_code(user_code)
111
- user_code = re.sub(r"eval\((.*)\)", r"int(\1) # Removed eval for security", user_code)
112
- user_code = re.sub(r"/ 0", "/ 1 # Fixed division by zero", user_code)
113
-
114
- # AI-powered optimization
115
- prompt = f"Optimize this {lang} code for efficiency and security:\n```{lang}\n{user_code}\n```\nOptimized version:"
116
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
117
-
118
- with torch.no_grad():
119
- outputs = model.generate(**inputs, max_length=1024)
120
-
121
- optimized_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
122
-
123
- # Extract just the code block if LLM added explanation
124
- code_match = re.search(r'```(?:python)?\n(.*?)\n```', optimized_code, re.DOTALL)
125
- if code_match:
126
- optimized_code = code_match.group(1)
127
-
128
- return optimized_code if optimized_code else user_code
 
 
 
129
 
130
  # --- API Endpoints ---
131
  @app.post("/evaluate")
@@ -148,7 +167,11 @@ async def optimize_endpoint(request: CodeRequest):
148
 
149
  @app.get("/")
150
  def health_check():
151
- return {"status": "Code Evaluation API is running!"}
 
 
 
 
152
 
153
  # For local testing
154
  if __name__ == "__main__":
 
6
import subprocess
import time
import re
import os
from pathlib import Path

app = FastAPI(title="Code Evaluation & Optimization API")

# --- Environment Setup ---
# NOTE(review): a root-level path only works when the process owns "/"
# (e.g. a container running as root) — confirm this is writable in the
# target deployment.
CACHE_DIR = Path("/.cache/huggingface")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# TRANSFORMERS_CACHE is deprecated in recent transformers releases; it is
# kept here for backward compatibility while HF_HOME is the supported knob.
os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR)
os.environ["HF_HOME"] = str(CACHE_DIR)

# --- Load AI Model ---
MODEL_NAME = "codellama/CodeLlama-7b-hf"

try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=str(CACHE_DIR),
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,
        cache_dir=str(CACHE_DIR),
    )
except Exception as e:
    # Chain the original exception so the startup log shows the real cause
    # (download failure, OOM, missing weights, ...) instead of only the
    # wrapper message.
    raise RuntimeError(f"Failed to load model: {str(e)}") from e
 
37
  # --- Request Models ---
38
  class CodeRequest(BaseModel):
 
121
 
122
def optimize_code_ai(user_code: str, lang: str) -> str:
    """Generate an optimized version of *user_code* via formatting plus the LLM.

    Args:
        user_code: Raw source code submitted by the client.
        lang: Language tag; only ``"python"`` receives the autopep8/regex
            pre-passes.

    Returns:
        The optimized code, or the (possibly pre-formatted) input when the
        model produced no usable output.

    Raises:
        HTTPException: 500 when any step of the pipeline fails.
    """
    try:
        # Deterministic clean-up passes before involving the model.
        if lang == "python":
            user_code = autopep8.fix_code(user_code)
            # Non-greedy group: the original greedy (.*) swallowed everything
            # up to the LAST ')' on the line. The replacement must not append
            # an inline '# ...' comment — that would comment out any code
            # following the match on the same line.
            user_code = re.sub(r"eval\((.*?)\)", r"int(\1)", user_code)
            # Replace the literal only, and never inside a float such as
            # '/ 0.5'; the original inline-comment replacement truncated the
            # rest of the line.
            user_code = re.sub(r"/ 0(?![\d.])", "/ 1", user_code)

        # AI-powered optimization.
        prompt = f"Optimize this {lang} code for efficiency and security:\n```{lang}\n{user_code}\n```\nOptimized version:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            # max_new_tokens, not max_length: max_length counts the prompt
            # tokens too, so a long prompt would starve (or error out) the
            # generation budget.
            outputs = model.generate(**inputs, max_new_tokens=1024)

        optimized_code = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # The decoded text echoes the prompt, whose own ```...``` block is
        # the FIRST fenced match — take the LAST fenced block so we extract
        # the model's answer rather than returning the input unchanged.
        code_blocks = re.findall(r'```(?:python)?\n(.*?)\n```', optimized_code, re.DOTALL)
        if code_blocks:
            optimized_code = code_blocks[-1]

        return optimized_code if optimized_code else user_code
    except HTTPException:
        # Don't re-wrap HTTP errors raised intentionally above this layer.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"AI optimization failed: {str(e)}") from e
148
 
149
  # --- API Endpoints ---
150
  @app.post("/evaluate")
 
167
 
168
  @app.get("/")
169
  def health_check():
170
+ return {
171
+ "status": "Code Evaluation API is running!",
172
+ "model_loaded": MODEL_NAME,
173
+ "cache_dir": str(CACHE_DIR)
174
+ }
175
 
176
  # For local testing
177
  if __name__ == "__main__":