# CodeGen Kids Tutor API — FastAPI backend for a Hugging Face Space.
# (Space status banner "Sleeping" removed from the code listing.)
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import random | |
| import re | |
| from typing import List, Dict, Any, Optional | |
# FastAPI application instance for the kids' coding-tutor backend.
app = FastAPI(title="CodeGen Kids Tutor API")

# Add CORS middleware so a browser frontend on another origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify your frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model loading
print("Loading model and tokenizer...")
# Hugging Face Hub id of the fine-tuned causal LM used to generate problems.
MODEL_NAME = "AhmedMOstaFA10/codegen-kids-tutor"
# Filled in by startup_event() or lazily by get_model(); None until first load.
tokenizer = None
model = None
class ProblemRequest(BaseModel):
    """Request body for problem generation."""

    # Optional key into problem_prompts (e.g. "loops"); None means any topic.
    category: Optional[str] = None  # Optional category to filter problem prompts
class SolutionRequest(BaseModel):
    """Request body for solution checking."""

    # The learner's submitted Python code.
    code: str
    # The generated reference solution returned alongside the problem.
    reference_code: str
# Problem prompts categorized by topic.
# Each prompt follows the Instruction/Input/Solution template the fine-tuned
# model was trained on; the model is expected to continue after "# Solution:".
problem_prompts = {
    "arithmetic": [
        "# Instruction:\nGenerate a simple arithmetic problem suitable for a kid. Write a function with a short docstring and partial code.\n\n"
        "# Input:\nAddition, subtraction, or multiplication\n\n"
        "# Solution:\n"
    ],
    "strings": [
        "# Instruction:\nGenerate a basic string manipulation exercise suitable for a beginner. Write a function with a short docstring and partial code.\n\n"
        "# Input:\nA string operation like reversing, counting characters, or checking substrings\n\n"
        "# Solution:\n"
    ],
    "lists": [
        "# Instruction:\nGenerate a simple list-related problem for beginners. Write a function with a short docstring and partial implementation.\n\n"
        "# Input:\nSorting a list, finding max or min, or summing numbers\n\n"
        "# Solution:\n"
    ],
    "conditions": [
        "# Instruction:\nGenerate a basic Python problem using if-else conditions. Write a function with a docstring and a few lines of partial code.\n\n"
        "# Input:\nAge check, number comparison, or grade classification\n\n"
        "# Solution:\n"
    ],
    "loops": [
        "# Instruction:\nCreate a beginner-friendly problem that uses a for loop. Write a function with a clear docstring and partial implementation.\n\n"
        "# Input:\nSumming numbers, iterating over lists, or counting even numbers\n\n"
        "# Solution:\n",
        "# Instruction:\nWrite a basic programming problem involving a while loop. Include a function definition, a short docstring, and partial implementation.\n\n"
        "# Input:\nRepeating until condition is met, counting, or basic input validation\n\n"
        "# Solution:\n"
    ],
    "dictionaries": [
        "# Instruction:\nGenerate an easy dictionary-based Python exercise. Write a function with a short docstring and partial implementation.\n\n"
        "# Input:\nAccessing values, summing values, or checking keys in a dictionary\n\n"
        "# Solution:\n"
    ],
    "input_output": [
        "# Instruction:\nWrite a problem simulating user input and output in Python. Provide a function with a docstring and a few lines of implementation.\n\n"
        "# Input:\nName, age, or favorite color, and return a formatted string\n\n"
        "# Solution:\n"
    ],
    "math": [
        "# Instruction:\nGenerate a Python problem that implements a basic math formula. Include a function with a docstring and partial code.\n\n"
        "# Input:\nArea of circle, BMI calculation, or temperature conversion\n\n"
        "# Solution:\n"
    ],
    "boolean": [
        "# Instruction:\nCreate a beginner-friendly Python exercise using boolean logic. Write a function with a docstring and partial implementation.\n\n"
        "# Input:\nCheck conditions like even AND positive, or NOT equal to zero\n\n"
        "# Solution:\n"
    ]
}
# Flatten every category's prompt list into one pool for random selection.
all_prompts = [prompt for prompts in problem_prompts.values() for prompt in prompts]
@app.on_event("startup")
async def startup_event():
    """Load the tokenizer and model once when the server starts.

    On failure we only log the error; get_model() retries lazily on the
    first request instead of keeping the whole app from booting.

    Fix: the function was defined but never registered with FastAPI, so it
    never ran — the @app.on_event("startup") decorator was missing.
    """
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Causal LMs often ship without a pad token; reuse EOS so generate() can pad.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model.config.pad_token_id = tokenizer.pad_token_id
        # Check for GPU availability and move the model there when possible.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        print(f"Model loaded successfully on {device}")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        # We'll initialize lazily if this fails on startup
def get_model():
    """Return the (tokenizer, model) pair, loading both on first use.

    Acts as a lazy fallback in case the startup load failed or never ran;
    subsequent calls return the cached globals.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        # No pad token on the fine-tuned checkpoint — fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    model.config.pad_token_id = tokenizer.pad_token_id
    # Place the model on GPU when one is available, CPU otherwise.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(target_device)
    return tokenizer, model
def generate_full_solution(prompt):
    """Generate text continuing `prompt` with the fine-tuned model.

    Returns the decoded string, which includes the prompt itself followed by
    the model's sampled completion.

    Fix: the attention mask produced by the tokenizer was being dropped, so
    generate() had to infer it from pad_token_id (which equals EOS here) —
    now it is passed explicitly.
    """
    tokenizer, model = get_model()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(
            inputs["input_ids"],
            # Explicit mask: with pad_token == eos_token, letting transformers
            # guess the mask can silently mask real EOS tokens.
            attention_mask=inputs.get("attention_mask"),
            max_length=256,  # NOTE(review): counts prompt tokens too; consider max_new_tokens
            num_return_sequences=1,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )
    full_solution = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return full_solution
def truncate_function_body(code):
    """Cut `code` down to a short starter snippet.

    Keeps lines up to and including the first line whose stripped text
    begins with 'return' or 'print', or until four lines have been kept,
    whichever happens first.
    """
    kept = []
    for raw_line in code.strip().split('\n'):
        kept.append(raw_line)
        head = raw_line.strip()
        # Stop once the snippet reveals an answer line or grows past 4 lines.
        if head.startswith(('return', 'print')) or len(kept) >= 4:
            break
    return '\n'.join(kept)
@app.get("/")
def read_root():
    """Health-check endpoint confirming the API is up.

    Fix: the handler was defined but never registered as a route — the
    @app.get("/") decorator was missing.
    """
    return {"message": "CodeGen Kids Tutor API is running!"}
@app.post("/generate-problem")  # NOTE(review): route path inferred; align with the frontend's URL
def generate_problem(request: ProblemRequest):
    """Generate a beginner coding problem, optionally filtered by category.

    Returns the problem statement, a truncated starter snippet, and the same
    snippet as reference_code for later verification.

    Fixes: the handler was never registered as a route (decorator missing),
    and the 400 HTTPException raised inside the try block was swallowed by
    `except Exception` and re-raised as a 500 — it is now re-raised intact.
    """
    try:
        # Select prompts based on category if provided; fall back to all topics.
        if request.category and request.category in problem_prompts:
            selected_prompts = problem_prompts[request.category]
        else:
            selected_prompts = all_prompts

        if not selected_prompts:
            raise HTTPException(status_code=400, detail="No problem prompts available for the selected category")

        problem_prompt = random.choice(selected_prompts)
        complete_solution = generate_full_solution(problem_prompt)

        # Split generated text into a prose problem statement and code lines
        # (def headers, docstring openers, and comments).
        problem_lines = []
        function_lines = []
        for line in complete_solution.strip().split('\n'):
            stripped = line.strip()
            if stripped.startswith("def ") or stripped.startswith('"""') or stripped.startswith("#"):
                function_lines.append(line)
            else:
                # NOTE(review): indented body lines (e.g. "    x = 1") land here,
                # not in function_lines — confirm this matches the model's output.
                problem_lines.append(line)

        current_problem = '\n'.join(problem_lines[:2]).strip()
        truncated_solution = truncate_function_body('\n'.join(function_lines))
        return {
            "problem": current_problem,
            "starter_code": truncated_solution,
            "reference_code": truncated_solution  # For verification later
        }
    except HTTPException:
        # Don't mask deliberate HTTP errors (the 400 above) as generic 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating problem: {str(e)}")
@app.post("/check-solution")  # NOTE(review): route path inferred; align with the frontend's URL
def check_solution(request: SolutionRequest):
    """Heuristically check a learner's code against the reference solution.

    Runs a syntax check (compile only — the submitted code is never
    executed), a function-name check, then a difflib similarity score
    against the reference. Returns is_correct plus kid-friendly feedback.

    Fixes: the handler was never registered as a route (decorator missing),
    and the mojibake "๐" characters in the feedback strings (a lost emoji
    encoding) are restored to real emojis.
    """
    # Import difflib's SequenceMatcher for similarity scoring; hoisted to the
    # top of the function instead of mid-logic.
    from difflib import SequenceMatcher

    try:
        user_solution = request.code.strip()
        reference_code = request.reference_code.strip()

        # Basic syntax check: compile() parses without executing, so the
        # untrusted submission never runs.
        try:
            compile(user_solution, '<string>', 'exec')
        except Exception as e:
            return {
                "is_correct": False,
                "feedback": f"Syntax error: {str(e)}"
            }

        # Function name check — the learner must keep the original name.
        func_name_pattern = r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        model_func_match = re.search(func_name_pattern, reference_code)
        user_func_match = re.search(func_name_pattern, user_solution)
        if model_func_match and user_func_match:
            if model_func_match.group(1) != user_func_match.group(1):
                return {
                    "is_correct": False,
                    "feedback": "You changed the function name. Keep the original function name."
                }

        # Text similarity is a rough correctness proxy for these tiny exercises.
        similarity = SequenceMatcher(None, reference_code, user_solution).ratio()
        if similarity > 0.5:
            return {
                "is_correct": True,
                "feedback": "Your solution looks correct! Great job! 🎉"
            }
        elif similarity > 0.3:
            return {
                "is_correct": True,
                "feedback": "Your solution passes, but there might be a more efficient approach. Keep going! 👍"
            }
        else:
            return {
                "is_correct": False,
                "feedback": "Your solution differs significantly from the expected solution. Try again! 🤔"
            }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error checking solution: {str(e)}")