# Source: Hugging Face file view — author shaheerawan3, commit 0bd16b4 ("Update model.py", verified)
# model.py
import ast
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Module logger; the root logger is configured for INFO-level output
# (basicConfig is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@dataclass(init=True, repr=True, eq=True)
class CodeAnalysis:
    """Container for the results of analysing a code snippet.

    Every field except ``complexity`` is a list of human-readable strings.
    """

    complexity: float  # numeric complexity score of the analysed code
    patterns: List[str]  # observed code patterns
    suggestions: List[str]  # general improvement suggestions
    security_issues: List[str]  # security-related remarks
    performance_tips: List[str]  # performance-related remarks
class CodeTeachingAssistant:
    """Code-teaching helper that loads a Hugging Face causal language model.

    Construction eagerly loads ``self.tokenizer``, ``self.model`` and a
    ``text-generation`` pipeline ``self.pipe``.  The analysis/teaching
    methods in this class do not call the model: they return an AST-based
    complexity score plus fixed placeholder advice.
    """

    # AST node types that each add one point in ``_calculate_complexity``.
    # The async variants are listed so ``async def`` / ``async for`` count
    # the same as ``def`` / ``for``.
    _COMPLEXITY_NODES = (
        ast.If,
        ast.While,
        ast.For,
        ast.AsyncFor,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
    )

    def __init__(self, model_name: str = "gpt2"):  # Using smaller model for faster loading
        """Store ``model_name`` and load the model immediately.

        Args:
            model_name: Model id (or local path) passed to ``from_pretrained``.

        Raises:
            Exception: any error raised while loading is logged and re-raised.
        """
        self.model_name = model_name
        self._initialize_model()

    def _initialize_model(self) -> None:
        """Initialize model optimized for Hugging Face spaces.

        Sets ``self.tokenizer``, ``self.model`` and ``self.pipe``.  Half
        precision is used only when CUDA is available, and in that case the
        pipeline is placed on GPU 0.  Otherwise the model stays on CPU in
        float32: float16 on CPU is slow, and many CPU kernels lack float16
        implementations in older PyTorch releases.

        Raises:
            Exception: re-raised after logging (with traceback) on failure.
        """
        try:
            logger.info("Loading model: %s", self.model_name)
            # Load tokenizer with fast tokenization
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=True,
            )
            use_cuda = torch.cuda.is_available()
            # Load model with optimizations; fp16 only where it is supported.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            )
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if use_cuda else -1,  # -1 selects CPU
                max_length=256,  # Reduced max length (prompt + generated tokens)
                do_sample=True,  # temperature is ignored unless sampling is on
                temperature=0.7,
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise

    def analyze_code_quality(self, code: str) -> "Optional[CodeAnalysis]":
        """Return a simplified quality report for ``code``.

        Only ``complexity`` is derived from the code; the other fields are
        fixed placeholder advice.

        Args:
            code: Python source text to analyse.

        Returns:
            A ``CodeAnalysis``, or ``None`` if the code could not be parsed
            or scored (the error is logged).
        """
        try:
            tree = ast.parse(code)
            # Basic complexity calculation
            complexity = self._calculate_complexity(tree)
        except Exception as e:
            # Best-effort by design: user-supplied code may be invalid
            # (e.g. SyntaxError from ast.parse); callers get None instead.
            logger.error("Error in code analysis: %s", e)
            return None
        # Simplified analysis
        return CodeAnalysis(
            complexity=complexity,
            patterns=["Basic patterns analysis"],
            suggestions=["Keep functions small", "Add comments for clarity"],
            security_issues=["Review input validation"],
            performance_tips=["Consider caching results"],
        )

    def _calculate_complexity(self, tree: ast.AST) -> float:
        """Return a rough complexity score for ``tree``.

        Adds one point per ``if`` / ``while`` / ``for`` (sync or async) and
        per function definition (sync or async) anywhere in the tree.  This
        is a simplified proxy, not full McCabe cyclomatic complexity: boolean
        operators, ``except`` handlers, comprehensions, etc. are not counted.
        """
        return sum(isinstance(node, self._COMPLEXITY_NODES) for node in ast.walk(tree))

    def generate_test_cases(self, code: str) -> List[Dict]:
        """Generate basic test cases.

        Currently returns one fixed placeholder test case; ``code`` is unused.
        A new list is built on every call.
        """
        return [
            {
                'name': 'Basic Test',
                'code': 'def test_basic(): pass',
                'purpose': 'Basic functionality test',
                'expected_output': 'None',
            }
        ]

    def learning_path_generator(self, code: str, user_level: str) -> Dict:
        """Generate simplified learning path.

        Echoes ``user_level`` back as ``current_level``; all other entries
        are fixed placeholders and ``code`` is unused.
        """
        return {
            'current_level': user_level,
            'concepts_to_learn': ['Basic Programming', 'Code Organization'],
            'recommended_exercises': ['Practice basic algorithms'],
            'learning_resources': ['Official documentation'],
            'estimated_timeline': '2-4 weeks',
        }

    def real_time_pair_programming(self, code_stream: str) -> Dict:
        """Provide basic real-time suggestions.

        ``auto_completion`` is ``code_stream`` with a placeholder ``pass``
        line appended; the remaining entries are fixed advice strings.
        """
        # NOTE(review): the appended line is indented by a single space;
        # confirm this is intended rather than a formatting artifact.
        return {
            'auto_completion': code_stream + "\n pass",
            'style_suggestions': ['Use consistent indentation'],
            'optimization_hints': ['Consider using built-in functions'],
            'documentation_hints': ['Add docstrings to functions'],
        }

    def code_review_assistant(self, code: str) -> Dict:
        """Provide basic code review.

        Returns fixed placeholder review data (``maintainability_score`` is
        always 75); ``code`` is currently unused.
        """
        return {
            'style_issues': ['Check PEP 8 compliance'],
            'best_practices': ['Write descriptive variable names'],
            'maintainability_score': 75,
            'suggested_refactoring': ['Extract complex logic into functions'],
            'documentation_quality': ['Add more inline comments'],
        }