Spaces:
Sleeping
Sleeping
Switch to FLAN-T5-Large: uses standard HF storage, excellent for question generation
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any
|
|
| 6 |
from contextlib import asynccontextmanager
|
| 7 |
|
| 8 |
import torch
|
| 9 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 10 |
import uvicorn
|
| 11 |
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -55,15 +55,26 @@ async def load_model_with_retry(model_name: str, hf_token: str, max_retries: int
|
|
| 55 |
token=hf_token
|
| 56 |
)
|
| 57 |
|
| 58 |
-
model
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
return tokenizer, model
|
| 69 |
|
|
@@ -101,8 +112,8 @@ async def load_model():
|
|
| 101 |
try:
|
| 102 |
logger.info("Loading model with transformers...")
|
| 103 |
|
| 104 |
-
# Use
|
| 105 |
-
base_model_name = "
|
| 106 |
|
| 107 |
tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
|
| 108 |
|
|
@@ -170,42 +181,29 @@ app.add_middleware(
|
|
| 170 |
)
|
| 171 |
|
| 172 |
def create_question_prompt(statement: str, num_questions: int, difficulty_level: str) -> str:
|
| 173 |
-
"""Create a prompt for question generation
|
| 174 |
|
| 175 |
difficulty_instruction = {
|
| 176 |
-
"easy": "
|
| 177 |
-
"medium": "
|
| 178 |
-
"hard": "
|
| 179 |
-
"mixed": "
|
| 180 |
}
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
You are an expert educator and question generator. Your task is to create thoughtful, well-crafted questions from given statements."""
|
| 185 |
-
|
| 186 |
-
user_prompt = f"""<think>
|
| 187 |
-
I need to analyze this statement and generate {num_questions} high-quality questions. Let me think about:
|
| 188 |
-
1. The key concepts and information in the statement
|
| 189 |
-
2. Different types of questions I can ask (factual, analytical, inferential, evaluative)
|
| 190 |
-
3. The difficulty level requested: {difficulty_level}
|
| 191 |
-
4. How to make questions that promote understanding and critical thinking
|
| 192 |
-
</think>
|
| 193 |
-
|
| 194 |
-
Based on the following statement, generate exactly {num_questions} questions.
|
| 195 |
|
| 196 |
-
|
| 197 |
|
| 198 |
Requirements:
|
| 199 |
-
-
|
| 200 |
-
-
|
| 201 |
-
- Vary the question types (what, how, why, when, where, etc.)
|
| 202 |
-
- Each question should test different aspects of the statement
|
| 203 |
-
- Make questions engaging and thought-provoking
|
| 204 |
- Number each question (1., 2., 3., etc.)
|
|
|
|
| 205 |
|
| 206 |
-
|
| 207 |
|
| 208 |
-
return
|
| 209 |
|
| 210 |
def extract_questions(generated_text: str) -> List[str]:
|
| 211 |
"""Extract questions from the generated text"""
|
|
@@ -275,27 +273,24 @@ async def generate_questions(request: QuestionGenerationRequest):
|
|
| 275 |
)
|
| 276 |
|
| 277 |
# Generate response using transformers
|
| 278 |
-
inputs = tokenizer
|
| 279 |
if device == "cuda":
|
| 280 |
inputs = inputs.to(device)
|
| 281 |
|
| 282 |
with torch.no_grad():
|
|
|
|
| 283 |
outputs = model.generate(
|
| 284 |
-
inputs,
|
| 285 |
-
max_new_tokens=request.max_length,
|
| 286 |
temperature=request.temperature,
|
| 287 |
top_p=0.95,
|
| 288 |
-
top_k=40,
|
| 289 |
-
repetition_penalty=1.1,
|
| 290 |
do_sample=True,
|
| 291 |
-
|
| 292 |
-
|
| 293 |
)
|
| 294 |
|
| 295 |
-
# Decode the generated text
|
| 296 |
-
|
| 297 |
-
# Remove the input prompt from the response
|
| 298 |
-
generated_text = full_response[len(prompt):].strip()
|
| 299 |
logger.info(f"Generated text length: {len(generated_text)}")
|
| 300 |
|
| 301 |
# Extract questions from the generated text
|
|
@@ -313,7 +308,7 @@ async def generate_questions(request: QuestionGenerationRequest):
|
|
| 313 |
questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
|
| 314 |
|
| 315 |
metadata = {
|
| 316 |
-
"model": "
|
| 317 |
"temperature": request.temperature,
|
| 318 |
"difficulty_level": request.difficulty_level,
|
| 319 |
"generated_text_length": len(generated_text),
|
|
@@ -337,7 +332,7 @@ async def root():
|
|
| 337 |
"""Root endpoint with basic info"""
|
| 338 |
return {
|
| 339 |
"message": "Question Generation API",
|
| 340 |
-
"model": "
|
| 341 |
"endpoints": {
|
| 342 |
"health": "/health",
|
| 343 |
"generate": "/generate-questions",
|
|
|
|
| 6 |
from contextlib import asynccontextmanager
|
| 7 |
|
| 8 |
import torch
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
|
| 10 |
import uvicorn
|
| 11 |
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 55 |
token=hf_token
|
| 56 |
)
|
| 57 |
|
| 58 |
+
# Use Seq2Seq model for T5-based models, CausalLM for others
|
| 59 |
+
if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
|
| 60 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 61 |
+
model_name,
|
| 62 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 63 |
+
device_map="auto" if device == "cuda" else None,
|
| 64 |
+
trust_remote_code=True,
|
| 65 |
+
low_cpu_mem_usage=True,
|
| 66 |
+
token=hf_token
|
| 67 |
+
)
|
| 68 |
+
else:
|
| 69 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 70 |
+
model_name,
|
| 71 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 72 |
+
device_map="auto" if device == "cuda" else None,
|
| 73 |
+
trust_remote_code=True,
|
| 74 |
+
low_cpu_mem_usage=True,
|
| 75 |
+
use_safetensors=True, # Force safetensors to avoid CVE-2025-32434
|
| 76 |
+
token=hf_token
|
| 77 |
+
)
|
| 78 |
|
| 79 |
return tokenizer, model
|
| 80 |
|
|
|
|
| 112 |
try:
|
| 113 |
logger.info("Loading model with transformers...")
|
| 114 |
|
| 115 |
+
# Use FLAN-T5 Large - excellent for question generation and uses standard HF storage
|
| 116 |
+
base_model_name = "google/flan-t5-large"
|
| 117 |
|
| 118 |
tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
|
| 119 |
|
|
|
|
| 181 |
)
|
| 182 |
|
| 183 |
def create_question_prompt(statement: str, num_questions: int, difficulty_level: str) -> str:
|
| 184 |
+
"""Create a prompt for question generation optimized for T5/FLAN models"""
|
| 185 |
|
| 186 |
difficulty_instruction = {
|
| 187 |
+
"easy": "simple, straightforward questions that test basic understanding",
|
| 188 |
+
"medium": "questions that require some analysis and comprehension",
|
| 189 |
+
"hard": "complex questions that require deep thinking and reasoning",
|
| 190 |
+
"mixed": "a mix of easy, medium, and hard questions"
|
| 191 |
}
|
| 192 |
|
| 193 |
+
# T5/FLAN models work better with direct, concise instructions
|
| 194 |
+
prompt = f"""Generate {num_questions} {difficulty_instruction[difficulty_level]} about this statement:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
+
"{statement}"
|
| 197 |
|
| 198 |
Requirements:
|
| 199 |
+
- Clear, well-formed questions
|
| 200 |
+
- Vary question types (what, how, why, when, where)
|
|
|
|
|
|
|
|
|
|
| 201 |
- Number each question (1., 2., 3., etc.)
|
| 202 |
+
- End each question with a question mark
|
| 203 |
|
| 204 |
+
Questions:"""
|
| 205 |
|
| 206 |
+
return prompt
|
| 207 |
|
| 208 |
def extract_questions(generated_text: str) -> List[str]:
|
| 209 |
"""Extract questions from the generated text"""
|
|
|
|
| 273 |
)
|
| 274 |
|
| 275 |
# Generate response using transformers
|
| 276 |
+
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
| 277 |
if device == "cuda":
|
| 278 |
inputs = inputs.to(device)
|
| 279 |
|
| 280 |
with torch.no_grad():
|
| 281 |
+
# T5 models use generate differently - they don't include input in output
|
| 282 |
outputs = model.generate(
|
| 283 |
+
**inputs,
|
| 284 |
+
max_new_tokens=min(request.max_length, 512),
|
| 285 |
temperature=request.temperature,
|
| 286 |
top_p=0.95,
|
|
|
|
|
|
|
| 287 |
do_sample=True,
|
| 288 |
+
num_beams=1,
|
| 289 |
+
early_stopping=True
|
| 290 |
)
|
| 291 |
|
| 292 |
+
# Decode the generated text (T5 doesn't include input prompt in output)
|
| 293 |
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
| 294 |
logger.info(f"Generated text length: {len(generated_text)}")
|
| 295 |
|
| 296 |
# Extract questions from the generated text
|
|
|
|
| 308 |
questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
|
| 309 |
|
| 310 |
metadata = {
|
| 311 |
+
"model": "google/flan-t5-large",
|
| 312 |
"temperature": request.temperature,
|
| 313 |
"difficulty_level": request.difficulty_level,
|
| 314 |
"generated_text_length": len(generated_text),
|
|
|
|
| 332 |
"""Root endpoint with basic info"""
|
| 333 |
return {
|
| 334 |
"message": "Question Generation API",
|
| 335 |
+
"model": "google/flan-t5-large",
|
| 336 |
"endpoints": {
|
| 337 |
"health": "/health",
|
| 338 |
"generate": "/generate-questions",
|