shaheerawan3 commited on
Commit
0bd16b4
·
verified ·
1 Parent(s): 70d7253

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +53 -128
model.py CHANGED
@@ -1,14 +1,10 @@
1
  # model.py
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import torch
4
- from typing import List, Dict, Optional
5
  import logging
6
- from functools import lru_cache
7
- import gc
8
- import ast
9
- import numpy as np
10
  from dataclasses import dataclass
11
- import time
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
@@ -22,177 +18,106 @@ class CodeAnalysis:
22
  performance_tips: List[str]
23
 
24
  class CodeTeachingAssistant:
25
- def __init__(self, model_name: str = "codellama/CodeLlama-7b-hf"):
26
  self.model_name = model_name
27
  self._initialize_model()
28
 
29
  def _initialize_model(self) -> None:
30
- """Initialize model with fallback to CPU if GPU not available."""
31
  try:
32
  logger.info(f"Loading model: {self.model_name}")
33
 
34
- # Load tokenizer
35
  self.tokenizer = AutoTokenizer.from_pretrained(
36
  self.model_name,
37
- trust_remote_code=True
38
  )
39
 
40
- # Load model with CPU compatibility
41
  self.model = AutoModelForCausalLM.from_pretrained(
42
  self.model_name,
43
- trust_remote_code=True,
44
- device_map="cpu", # Force CPU usage
45
- low_cpu_mem_usage=True
46
  )
47
 
48
  self.pipe = pipeline(
49
  "text-generation",
50
  model=self.model,
51
  tokenizer=self.tokenizer,
52
- max_length=512,
53
  temperature=0.7,
54
- device="cpu" # Ensure CPU usage
55
  )
56
 
57
- logger.info("Model loaded successfully on CPU")
58
 
59
  except Exception as e:
60
  logger.error(f"Error loading model: {str(e)}")
61
  raise
62
 
63
  def analyze_code_quality(self, code: str) -> CodeAnalysis:
64
- """Analyze code quality metrics and patterns."""
65
  try:
66
  tree = ast.parse(code)
67
 
68
- # Calculate cyclomatic complexity
69
  complexity = self._calculate_complexity(tree)
70
 
71
- # Identify design patterns
72
- patterns = self._identify_patterns(tree)
73
-
74
- # Generate improvement suggestions
75
- suggestions = self._generate_suggestions(tree)
76
-
77
- # Check for security issues
78
- security_issues = self._check_security(tree)
79
-
80
- # Generate performance tips
81
- performance_tips = self._analyze_performance(tree)
82
-
83
  return CodeAnalysis(
84
  complexity=complexity,
85
- patterns=patterns,
86
- suggestions=suggestions,
87
- security_issues=security_issues,
88
- performance_tips=performance_tips
89
  )
90
  except Exception as e:
91
  logger.error(f"Error in code analysis: {str(e)}")
92
  return None
93
 
94
- def generate_test_cases(self, code: str) -> List[Dict]:
95
- """Generate comprehensive test cases for the code."""
96
- try:
97
- tree = ast.parse(code)
98
- test_cases = []
99
-
100
- # Analyze function signatures
101
- for node in ast.walk(tree):
102
- if isinstance(node, ast.FunctionDef):
103
- # Generate edge cases
104
- edge_cases = self._generate_edge_cases(node)
105
- # Generate boundary cases
106
- boundary_cases = self._generate_boundary_cases(node)
107
- # Generate typical cases
108
- typical_cases = self._generate_typical_cases(node)
109
-
110
- test_cases.extend(edge_cases + boundary_cases + typical_cases)
111
-
112
- return test_cases
113
- except Exception as e:
114
- logger.error(f"Error generating test cases: {str(e)}")
115
- return []
116
 
117
- def interactive_debugging(self, code: str) -> List[Dict]:
118
- """Provide interactive debugging suggestions."""
119
- try:
120
- issues = []
121
- tree = ast.parse(code)
122
-
123
- # Check for common bugs
124
- issues.extend(self._check_common_bugs(tree))
125
-
126
- # Check for logic errors
127
- issues.extend(self._check_logic_errors(tree))
128
-
129
- # Generate fix suggestions
130
- for issue in issues:
131
- issue['fix_suggestion'] = self._generate_fix(issue)
132
-
133
- return issues
134
- except Exception as e:
135
- logger.error(f"Error in debugging: {str(e)}")
136
- return []
137
 
138
  def learning_path_generator(self, code: str, user_level: str) -> Dict:
139
- """Generate personalized learning path based on code analysis."""
140
- concepts = self.identify_concepts(code)
141
  return {
142
- 'current_level': self._assess_code_level(code),
143
- 'concepts_to_learn': self._identify_learning_gaps(concepts, user_level),
144
- 'recommended_exercises': self._generate_exercises(concepts, user_level),
145
- 'learning_resources': self._recommend_resources(concepts, user_level),
146
- 'estimated_timeline': self._generate_timeline(concepts, user_level)
147
  }
148
 
149
  def real_time_pair_programming(self, code_stream: str) -> Dict:
150
- """Provide real-time suggestions during coding."""
151
  return {
152
- 'auto_completion': self._generate_completion(code_stream),
153
- 'style_suggestions': self._check_style(code_stream),
154
- 'optimization_hints': self._suggest_optimizations(code_stream),
155
- 'documentation_hints': self._suggest_documentation(code_stream)
156
  }
157
 
158
  def code_review_assistant(self, code: str) -> Dict:
159
- """Provide comprehensive code review."""
160
  return {
161
- 'style_issues': self._check_coding_style(code),
162
- 'best_practices': self._check_best_practices(code),
163
- 'maintainability_score': self._calculate_maintainability(code),
164
- 'suggested_refactoring': self._suggest_refactoring(code),
165
- 'documentation_quality': self._assess_documentation(code)
166
- }
167
-
168
- # Helper methods for code analysis
169
- def _calculate_complexity(self, tree: ast.AST) -> float:
170
- complexity = 0
171
- for node in ast.walk(tree):
172
- if isinstance(node, (ast.If, ast.While, ast.For, ast.FunctionDef)):
173
- complexity += 1
174
- return complexity
175
-
176
- def _identify_patterns(self, tree: ast.AST) -> List[str]:
177
- patterns = []
178
- # Add pattern recognition logic
179
- return patterns
180
-
181
- def _generate_suggestions(self, tree: ast.AST) -> List[str]:
182
- suggestions = []
183
- # Add suggestion generation logic
184
- return suggestions
185
-
186
- def _check_security(self, tree: ast.AST) -> List[str]:
187
- issues = []
188
- # Add security check logic
189
- return issues
190
-
191
- def _analyze_performance(self, tree: ast.AST) -> List[str]:
192
- tips = []
193
- # Add performance analysis logic
194
- return tips
195
-
196
- def _generate_completion(self, code_stream: str) -> str:
197
- prompt = f"Complete this code:\n{code_stream}"
198
- return self.pipe(prompt)[0]['generated_text']
 
1
  # model.py
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import torch
4
+ from typing import List, Dict
5
  import logging
 
 
 
 
6
  from dataclasses import dataclass
7
+ import ast
8
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
 
18
  performance_tips: List[str]
19
 
20
  class CodeTeachingAssistant:
21
+ def __init__(self, model_name: str = "gpt2"): # Using smaller model for faster loading
22
  self.model_name = model_name
23
  self._initialize_model()
24
 
25
  def _initialize_model(self) -> None:
26
+ """Initialize model optimized for Hugging Face spaces."""
27
  try:
28
  logger.info(f"Loading model: {self.model_name}")
29
 
30
+ # Load tokenizer with fast tokenization
31
  self.tokenizer = AutoTokenizer.from_pretrained(
32
  self.model_name,
33
+ use_fast=True
34
  )
35
 
36
+ # Load model with optimizations
37
  self.model = AutoModelForCausalLM.from_pretrained(
38
  self.model_name,
39
+ low_cpu_mem_usage=True,
40
+ torch_dtype=torch.float16 # Use half precision
 
41
  )
42
 
43
  self.pipe = pipeline(
44
  "text-generation",
45
  model=self.model,
46
  tokenizer=self.tokenizer,
47
+ max_length=256, # Reduced max length
48
  temperature=0.7,
 
49
  )
50
 
51
+ logger.info("Model loaded successfully")
52
 
53
  except Exception as e:
54
  logger.error(f"Error loading model: {str(e)}")
55
  raise
56
 
57
  def analyze_code_quality(self, code: str) -> CodeAnalysis:
58
+ """Simplified code analysis for better performance."""
59
  try:
60
  tree = ast.parse(code)
61
 
62
+ # Basic complexity calculation
63
  complexity = self._calculate_complexity(tree)
64
 
65
+ # Simplified analysis
 
 
 
 
 
 
 
 
 
 
 
66
  return CodeAnalysis(
67
  complexity=complexity,
68
+ patterns=["Basic patterns analysis"],
69
+ suggestions=["Keep functions small", "Add comments for clarity"],
70
+ security_issues=["Review input validation"],
71
+ performance_tips=["Consider caching results"]
72
  )
73
  except Exception as e:
74
  logger.error(f"Error in code analysis: {str(e)}")
75
  return None
76
 
77
+ def _calculate_complexity(self, tree: ast.AST) -> float:
78
+ """Calculate basic cyclomatic complexity."""
79
+ complexity = 0
80
+ for node in ast.walk(tree):
81
+ if isinstance(node, (ast.If, ast.While, ast.For, ast.FunctionDef)):
82
+ complexity += 1
83
+ return complexity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ def generate_test_cases(self, code: str) -> List[Dict]:
86
+ """Generate basic test cases."""
87
+ return [
88
+ {
89
+ 'name': 'Basic Test',
90
+ 'code': 'def test_basic(): pass',
91
+ 'purpose': 'Basic functionality test',
92
+ 'expected_output': 'None'
93
+ }
94
+ ]
 
 
 
 
 
 
 
 
 
 
95
 
96
  def learning_path_generator(self, code: str, user_level: str) -> Dict:
97
+ """Generate simplified learning path."""
 
98
  return {
99
+ 'current_level': user_level,
100
+ 'concepts_to_learn': ['Basic Programming', 'Code Organization'],
101
+ 'recommended_exercises': ['Practice basic algorithms'],
102
+ 'learning_resources': ['Official documentation'],
103
+ 'estimated_timeline': '2-4 weeks'
104
  }
105
 
106
  def real_time_pair_programming(self, code_stream: str) -> Dict:
107
+ """Provide basic real-time suggestions."""
108
  return {
109
+ 'auto_completion': code_stream + "\n pass",
110
+ 'style_suggestions': ['Use consistent indentation'],
111
+ 'optimization_hints': ['Consider using built-in functions'],
112
+ 'documentation_hints': ['Add docstrings to functions']
113
  }
114
 
115
  def code_review_assistant(self, code: str) -> Dict:
116
+ """Provide basic code review."""
117
  return {
118
+ 'style_issues': ['Check PEP 8 compliance'],
119
+ 'best_practices': ['Write descriptive variable names'],
120
+ 'maintainability_score': 75,
121
+ 'suggested_refactoring': ['Extract complex logic into functions'],
122
+ 'documentation_quality': ['Add more inline comments']
123
+ }