from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from TextToSpeech import text_to_speech
import torch
import base64

# =========================================
# ENUM MAPPINGS (Match Backend Enums)
# =========================================
# Integer IDs mirror the backend's enum values — keep in sync with that service.
SESSION_TYPES = {
    1: "technical",
    2: "softskills"
}

TRACKS = {
    19: "generalprogramming"
}

# =========================================
# LOAD MODEL ONCE (Global)
# =========================================
MODEL_PATH = "Fayza38/Question_and_Answer"

# Loaded once at import time so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float32,
    device_map="cpu"
)

app = FastAPI()

# =========================================
# REQUEST MODEL
# =========================================
class QuestionRequest(BaseModel):
    # sessionType / trackName carry backend enum IDs (see SESSION_TYPES / TRACKS).
    sessionType: int
    difficultyLevel: int | None = None
    trackName: int

# =========================================
# HELPER: GENERATE TEXT USING QWEN TEMPLATE
# =========================================
def generate_from_model(prompt: str) -> str:
    """Run the chat model on *prompt* and return only the newly generated text.

    Fix: for a decoder-only model, ``model.generate`` returns the prompt
    tokens followed by the completion. The original decoded the whole
    sequence, so the chat template / system prompt (which contains a
    ``Q: ... A: ...`` example) leaked into the text later parsed by
    ``parse_qa_blocks``. We now slice off the input tokens before decoding.

    Fix: ``temperature`` is ignored unless sampling is enabled, so
    ``do_sample=True`` is passed explicitly.
    """
    messages = [
        {"role": "system", "content": "You are a professional interview question generator."},
        {"role": "user", "content": prompt}
    ]

    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(formatted_prompt, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1200,
            do_sample=True,   # required for temperature to take effect
            temperature=0.7
        )

    # Decode only the tokens generated after the prompt.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return decoded

# =========================================
# PARSE Q/A FORMAT
# =========================================
def parse_qa_blocks(text: str) -> list[tuple[str, str]]:
    """Extract ``(question, answer)`` pairs from ``Q: ... A: ...`` blocks.

    Blocks are separated by blank lines; a block counts only if it contains
    both markers. Fix: split on the FIRST ``A:`` only, so an answer that
    itself contains the substring ``A:`` is not truncated.
    """
    results = []
    for block in text.split("\n\n"):
        if "Q:" in block and "A:" in block:
            question_part, answer_part = block.split("A:", 1)
            question = question_part.replace("Q:", "").strip()
            answer = answer_part.strip()
            results.append((question, answer))
    return results
# =========================================
# MAIN ENDPOINT
# =========================================
@app.post("/generate-questions")
def generate_questions(request: QuestionRequest):
    """Generate ten interview questions (with ideal answers) for a session.

    Resolves the session-type / track enum IDs, builds the matching prompt,
    invokes the model a single time, parses the ``Q:``/``A:`` blocks, and
    returns up to ten ``{questionText, questionId, questionIdealAnswer}``
    dicts. Raises HTTP 400 for unknown enum values and HTTP 500 when the
    model output yields no parsable pairs.
    """
    session_type = SESSION_TYPES.get(request.sessionType)
    if session_type is None:
        raise HTTPException(status_code=400, detail="Invalid session type")

    # ---------------- SOFT SKILLS ----------------
    if session_type == "softskills":
        prompt = """
Generate 10 behavioral interview questions.
Format exactly as:
Q: ...
A: ...
"""
    # ---------------- TECHNICAL ----------------
    elif session_type == "technical":
        if request.trackName not in TRACKS:
            raise HTTPException(status_code=400, detail="Track not supported")

        difficulty = request.difficultyLevel or 1
        prompt = f"""
Generate 10 General Programming interview questions.
Difficulty level: {difficulty}
Format exactly as:
Q: ...
A: ...
"""
    else:
        # Defensive: unreachable while SESSION_TYPES only maps the two types above.
        raise HTTPException(status_code=400, detail="Invalid session type")

    # -------- Generate once --------
    raw_output = generate_from_model(prompt)
    qa_pairs = parse_qa_blocks(raw_output)

    if not qa_pairs:
        raise HTTPException(status_code=500, detail="Model failed to generate valid Q/A format")

    # At most ten pairs, numbered from 1 in generation order.
    return [
        {
            "questionText": question,
            "questionId": idx,
            "questionIdealAnswer": answer,
        }
        for idx, (question, answer) in enumerate(qa_pairs[:10], 1)
    ]