from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer
|
|
app = FastAPI()


# Question-generation model and tokenizer, used by generate_question().
qg_model_name = "fares7elsadek/t5-base-finetuned-question-generation"
# use_fast=False makes AutoTokenizer pick the slow (Python) tokenizer class.
qg_tokenizer = AutoTokenizer.from_pretrained(qg_model_name, use_fast=False)
qg_model = AutoModelForSeq2SeqLM.from_pretrained(qg_model_name)


# Distractor model and tokenizer, used by generate_distractors().
# NOTE(review): this is a relative path, so it is resolved against the process's
# current working directory rather than this file's location — confirm the
# server is always launched from the directory that contains "t5_distractor".
distractor_model_path = "t5_distractor"
# T5Tokenizer is already the slow tokenizer class; the AutoTokenizer-only
# ``use_fast`` selector has no meaning here, so it is not passed.
distractor_tokenizer = T5Tokenizer.from_pretrained(distractor_model_path)
distractor_model = T5ForConditionalGeneration.from_pretrained(distractor_model_path)
|
|
|
|
| |
def generate_question(context, answer="[MASK]", max_length=64):
    """Generate a question whose answer is *answer*, grounded in *context*.

    Builds the ``context: ... answer: ... </s>`` prompt for the
    question-generation model, runs ``generate`` on it and decodes the first
    returned sequence with special tokens removed.

    Args:
        context: Source passage the question should be about.
        answer: Target answer text; defaults to the literal "[MASK]".
        max_length: ``max_length`` forwarded to ``qg_model.generate``.

    Returns:
        The decoded question string.
    """
    # NOTE(review): truncation=True trims from the right by default, so for a
    # context longer than the tokenizer's limit the "answer: ..." suffix is
    # what gets cut off — confirm callers keep contexts short enough.
    encoded = qg_tokenizer(
        [f"context: {context} answer: {answer} </s>"],
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    generated_ids = qg_model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
    )
    # A single prompt was encoded, so only the first sequence is relevant.
    first_sequence = generated_ids[0]
    return qg_tokenizer.decode(first_sequence, skip_special_tokens=True)
|
|
| |
def _unique_distractors(decoded_sequences, answer, limit):
    """Split comma-separated model outputs into at most *limit* unique distractors.

    Items are stripped of surrounding whitespace; empty items, repeats
    (compared case-insensitively) and items equal to *answer* itself are
    dropped, since a distractor identical to the correct answer would make the
    question ambiguous. First-seen order and casing are preserved.
    """
    seen = {answer.strip().casefold()}
    distractors = []
    for sequence in decoded_sequences:
        for item in sequence.split(","):
            if len(distractors) >= limit:
                return distractors
            item = item.strip()
            key = item.casefold()
            if not item or key in seen:
                continue
            seen.add(key)
            distractors.append(item)
    return distractors


def generate_distractors(answer, max_length=64, num_beams=5, num_return_sequences=5, num_distractors=3):
    """Generate plausible wrong options for *answer* with the distractor model.

    Runs beam search on a fixed prompt, splits every returned sequence on
    commas and keeps the first *num_distractors* distinct items that differ
    from the answer.

    Args:
        answer: The correct answer the distractors should resemble.
        max_length: Used both as the prompt truncation length and as the
            generation ``max_length``.
        num_beams: Beam width for ``generate``.
        num_return_sequences: Number of beams returned by ``generate``
            (Hugging Face requires this to be <= ``num_beams``).
        num_distractors: Maximum number of distractors returned; values <= 0
            return an empty list without running the model.

    Returns:
        A list of up to *num_distractors* distractor strings (possibly fewer if
        the model does not produce enough distinct candidates).
    """
    if num_distractors <= 0:
        return []
    # The "3" in the prompt is kept verbatim: it is presumably the phrasing the
    # distractor model was fine-tuned on, independent of num_distractors.
    prompt = f"generate 3 similar words or numbers for \"{answer}\""
    inputs = distractor_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
    outputs = distractor_model.generate(
        **inputs,
        max_length=max_length,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    decoded = [distractor_tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
    return _unique_distractors(decoded, answer, num_distractors)
|
|
def _parse_context_answer_pairs(text):
    """Extract ``(context, answer)`` pairs from the plain-text request body.

    A pair is a ``context = ...`` line immediately followed by an
    ``answer = ...`` line. Keys may be surrounded by whitespace, and values
    are stripped of whitespace and surrounding double quotes. A context line
    not directly followed by an answer line is ignored, without consuming the
    line after it. Answer lines without a preceding context, and any other
    lines, are ignored.
    """
    pairs = []
    pending_context = None
    for line in text.splitlines():
        key, sep, value = line.partition("=")
        key = key.strip()
        value = value.strip().strip('"')
        if sep and key == "context":
            # A newer context supersedes one that never received an answer.
            pending_context = value
        elif sep and key == "answer" and pending_context is not None:
            pairs.append((pending_context, value))
            pending_context = None
        else:
            # The answer must be on the very next line after its context.
            pending_context = None
    return pairs


def _build_items(pairs):
    """Run question and distractor generation for each ``(context, answer)`` pair.

    This is blocking model inference; the endpoint calls it through
    ``run_in_threadpool`` so it does not stall the event loop.
    """
    return [
        {
            "question": generate_question(context, answer),
            "answer": answer,
            "distractors": generate_distractors(answer),
        }
        for context, answer in pairs
    ]


@app.post("/generate")
async def generate(request: Request):
    """Generate multiple-choice items from a plain-text request body.

    The body is decoded as UTF-8 and scanned for ``context = ...`` lines each
    immediately followed by an ``answer = ...`` line. Every pair yields an
    item with ``question``, ``answer`` and ``distractors`` keys.

    Returns:
        A JSON array of items (empty when no pairs are found), or a 400 JSON
        error if the body is not valid UTF-8.
    """
    raw_body = await request.body()
    try:
        user_input = raw_body.decode("utf-8")
    except UnicodeDecodeError:
        return JSONResponse(
            status_code=400,
            content={"detail": "Request body must be UTF-8 encoded text."},
        )

    pairs = _parse_context_answer_pairs(user_input)
    # Model inference is synchronous and CPU/GPU-bound; run it in a worker
    # thread instead of blocking the async event loop for other requests.
    output = await run_in_threadpool(_build_items, pairs)
    return JSONResponse(content=output)
|
|