# Qgv1 / app.py
# Last commit: cd4715b "Update app.py" by Arie1L (verified)
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer
# FastAPI application; the POST /generate route further down is registered on it.
app = FastAPI()
# Both model/tokenizer pairs below are loaded once, at import time, and are
# shared by every request this process handles (the functions below read them
# as module globals).
# Question generation model
# Hugging Face Hub model ID; presumably downloaded and cached on first start.
qg_model_name = "fares7elsadek/t5-base-finetuned-question-generation"
# use_fast=False asks AutoTokenizer for the Python ("slow") tokenizer class
# instead of the Rust-backed "fast" one.
qg_tokenizer = AutoTokenizer.from_pretrained(qg_model_name, use_fast=False)
qg_model = AutoModelForSeq2SeqLM.from_pretrained(qg_model_name)
# Distractor model
# NOTE(review): presumably a local directory holding the fine-tuned distractor
# checkpoint, resolved relative to the current working directory — confirm it
# is shipped alongside the app.
distractor_model_path = "t5_distractor"
# NOTE(review): T5Tokenizer is already the slow tokenizer class, so
# use_fast=False appears to have no effect here.
distractor_tokenizer = T5Tokenizer.from_pretrained(distractor_model_path, use_fast=False)
distractor_model = T5ForConditionalGeneration.from_pretrained(distractor_model_path)
def generate_question(context, answer="[MASK]", max_length=64):
    """Produce a single question about *context* whose answer is *answer*.

    Args:
        context: Passage the question is generated from.
        answer: Answer span the question should target; defaults to the
            placeholder string ``"[MASK]"``.
        max_length: ``max_length`` forwarded to ``qg_model.generate``.

    Returns:
        The first generated sequence decoded to text, with special tokens
        removed.
    """
    # NOTE(review): the prompt ends with a literal "</s>"; the tokenizer may
    # also append its own end-of-sequence token. Kept as-is because this is
    # the input format the code feeds the model — confirm against the model card.
    prompt = f"context: {context} answer: {answer} </s>"
    encoded = qg_tokenizer(
        [prompt],
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    generated_ids = qg_model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
    )
    first_sequence = generated_ids[0]
    return qg_tokenizer.decode(first_sequence, skip_special_tokens=True)
def generate_distractors(answer, max_length=64, num_beams=5, num_return_sequences=5, num_distractors=3):
    """Generate wrong-but-plausible answer options for a multiple-choice item.

    The distractor model is prompted with a fixed instruction. Beam search
    returns ``num_return_sequences`` sequences, and each decoded sequence is
    split on commas into candidate items.

    Args:
        answer: The correct answer the distractors should resemble.
        max_length: Used both as the prompt truncation length and as the
            ``max_length`` passed to ``generate``.
        num_beams: Beam width for beam search.
        num_return_sequences: How many beams to decode. It must not exceed
            ``num_beams``, because ``generate`` rejects that combination.
        num_distractors: Maximum number of distractors to return.

    Returns:
        Up to ``num_distractors`` distinct strings, in first-seen order. Items
        that equal the correct answer or repeat an earlier item are dropped.
        Both comparisons ignore case and surrounding whitespace, so the list
        can be shorter than requested.
    """
    # NOTE(review): the prompt always says "3" regardless of num_distractors;
    # presumably this is the wording the model was fine-tuned on, so it is left
    # unchanged — confirm before making it dynamic.
    prompt = f"generate 3 similar words or numbers for \"{answer}\""
    inputs = distractor_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
    outputs = distractor_model.generate(
        **inputs,
        max_length=max_length,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    # Seed "seen" with the correct answer so the model can never hand it back
    # as a distractor; otherwise the question would have two right options.
    seen = {str(answer).strip().casefold()}
    candidates = []
    for sequence in outputs:
        text = distractor_tokenizer.decode(sequence, skip_special_tokens=True)
        for item in text.split(","):
            item = item.strip()
            key = item.casefold()
            if not item or key in seen:
                continue
            seen.add(key)
            candidates.append(item)
    return candidates[:num_distractors]
def _parse_pairs(text):
    """Extract ``(context, answer)`` pairs from the plain-text request body.

    Each field sits on its own line as ``key = value``. Whitespace around the
    key and value, and double quotes around the value, are stripped::

        context = "The Eiffel Tower is in Paris."
        answer = "Paris"

    An ``answer`` line is paired with the most recent ``context`` line that has
    not been paired yet. Blank lines, unknown keys and answers with no pending
    context are ignored.
    """
    pairs = []
    pending_context = None
    for line in text.splitlines():
        key, sep, value = line.partition("=")
        if not sep:
            continue  # not a "key = value" line
        key = key.strip()
        # Everything after the first "=" is the value, so values may contain "=".
        value = value.strip().strip('"')
        if key == "context":
            pending_context = value
        elif key == "answer" and pending_context is not None:
            pairs.append((pending_context, value))
            pending_context = None
    return pairs


def _build_items(pairs):
    """Generate a question and distractors for every pair (blocking model calls)."""
    return [
        {
            "question": generate_question(context, answer),
            "answer": answer,
            "distractors": generate_distractors(answer),
        }
        for context, answer in pairs
    ]


@app.post("/generate")
async def generate(request: Request):
    """Turn labelled context/answer pairs into multiple-choice question items.

    The request body is plain text (per the original comment, sent by a Unity
    client). It holds one or more ``context = ...`` / ``answer = ...`` line
    pairs; see ``_parse_pairs`` for the exact format.

    Returns:
        A JSON list with one ``{"question", "answer", "distractors"}`` object
        per parsed pair, in input order. The list is empty if nothing parsed.
        If the body is not valid UTF-8, a 400 response with an ``"error"``
        message is returned instead.
    """
    raw_body = await request.body()
    try:
        # "utf-8-sig" also drops a leading byte-order mark, which would
        # otherwise stop the first "context" key from being recognised.
        user_input = raw_body.decode("utf-8-sig")
    except UnicodeDecodeError:
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be UTF-8 text."},
        )

    pairs = _parse_pairs(user_input)
    # Model inference is blocking and can take seconds. Running it in the
    # thread pool keeps the event loop free to serve other requests meanwhile.
    output = await run_in_threadpool(_build_items, pairs)
    return JSONResponse(content=output)