# Fayza38's picture
# Upload 2 files
# b7be4fe verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from TextToSpeech import text_to_speech
import torch
import base64
# =========================================
# ENUM MAPPINGS (Match Backend Enums)
# =========================================
# Integer IDs sent by the backend, mapped to the session kinds this
# service knows how to prompt for. Keys must stay in sync with the
# backend's enum values.
SESSION_TYPES = {
    1: "technical",
    2: "softskills"
}
# Supported technical tracks. Only ID 19 (general programming) is
# implemented; other track IDs are rejected with HTTP 400.
TRACKS = {
    19: "generalprogramming"
}
# =========================================
# LOAD MODEL ONCE (Global)
# =========================================
# Model and tokenizer are loaded once at import time and shared by all
# requests (the endpoint below is sync, so FastAPI serializes access
# through its threadpool; the model itself is used read-only).
MODEL_PATH = "Fayza38/Question_and_Answer"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# float32 on CPU: no GPU assumed in the deployment environment.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float32,
    device_map="cpu"
)
app = FastAPI()
# =========================================
# REQUEST MODEL
# =========================================
class QuestionRequest(BaseModel):
    """Request body for POST /generate-questions.

    Field names are camelCase to match the backend's JSON contract.
    """
    # Backend enum ID; must be a key of SESSION_TYPES (1=technical, 2=softskills).
    sessionType: int
    # Optional difficulty; only used for technical sessions, defaults to 1.
    difficultyLevel: int | None = None
    # Backend track enum ID; must be a key of TRACKS for technical sessions.
    trackName: int
# =========================================
# HELPER: GENERATE TEXT USING QWEN TEMPLATE
# =========================================
def generate_from_model(prompt: str) -> str:
    """Run the chat model on *prompt* and return only the generated text.

    The prompt is wrapped in the model's chat template with a fixed
    system message, then generated greedily-sampled on CPU.

    Args:
        prompt: User-turn content (the question-generation instructions).

    Returns:
        The model's completion, with special tokens stripped and the
        echoed prompt removed.
    """
    messages = [
        {"role": "system", "content": "You are a professional interview question generator."},
        {"role": "user", "content": prompt}
    ]
    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1200,
            # do_sample must be True for temperature to take effect;
            # without it generation is greedy and temperature is ignored
            # (transformers emits a warning and samples nothing).
            do_sample=True,
            temperature=0.7
        )
    # Decode only the newly generated tokens. Decoding outputs[0] in full
    # would include the chat-template prompt, whose "Q: ... / A: ..."
    # format example could be mis-parsed as a real Q/A pair downstream.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# =========================================
# PARSE Q/A FORMAT
# =========================================
def parse_qa_blocks(text: str) -> list[tuple[str, str]]:
    """Parse blank-line-separated "Q: ... A: ..." blocks into pairs.

    Args:
        text: Raw model output; blocks are separated by blank lines and
            each valid block contains both a "Q:" and an "A:" marker.

    Returns:
        List of (question, answer) tuples in order of appearance; blocks
        missing either marker are skipped.
    """
    results = []
    for block in text.split("\n\n"):
        if "Q:" in block and "A:" in block:
            # Split on the FIRST "A:" only — a plain split("A:") would
            # truncate any answer that itself contains the substring "A:".
            q_part, _, a_part = block.partition("A:")
            question = q_part.replace("Q:", "").strip()
            answer = a_part.strip()
            results.append((question, answer))
    return results
# =========================================
# MAIN ENDPOINT
# =========================================
@app.post("/generate-questions")
def generate_questions(request: QuestionRequest):
    """Generate up to 10 interview Q/A pairs for the requested session.

    Validates the session type (and, for technical sessions, the track),
    builds the matching prompt, runs the model once, and returns the
    parsed Q/A pairs as a JSON list.
    """
    if request.sessionType not in SESSION_TYPES:
        raise HTTPException(status_code=400, detail="Invalid session type")
    kind = SESSION_TYPES[request.sessionType]

    # ---------------- SOFT SKILLS ----------------
    if kind == "softskills":
        prompt = """
Generate 10 behavioral interview questions.
Format exactly as:
Q: ...
A: ...
"""
    # ---------------- TECHNICAL ----------------
    elif kind == "technical":
        if request.trackName not in TRACKS:
            raise HTTPException(status_code=400, detail="Track not supported")
        level = request.difficultyLevel or 1
        prompt = f"""
Generate 10 General Programming interview questions.
Difficulty level: {level}
Format exactly as:
Q: ...
A: ...
"""
    else:
        # Defensive: unreachable while SESSION_TYPES only maps the two
        # kinds handled above, but kept so a new mapping can't fall through.
        raise HTTPException(status_code=400, detail="Invalid session type")

    # -------- Generate once --------
    raw = generate_from_model(prompt)
    pairs = parse_qa_blocks(raw)
    if not pairs:
        raise HTTPException(status_code=500, detail="Model failed to generate valid Q/A format")

    # Cap at 10 pairs and assign 1-based question IDs.
    return [
        {
            "questionText": question,
            "questionId": qid,
            "questionIdealAnswer": answer
        }
        for qid, (question, answer) in enumerate(pairs[:10], 1)
    ]