david167 committed on
Commit
203ee8d
·
1 Parent(s): 444b4d9

Switch to FLAN-T5-Large: uses standard HF storage, excellent for question generation

Browse files
Files changed (1) hide show
  1. app.py +46 -51
app.py CHANGED
@@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any
6
  from contextlib import asynccontextmanager
7
 
8
  import torch
9
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
10
  import uvicorn
11
  from fastapi import FastAPI, HTTPException, BackgroundTasks
12
  from fastapi.middleware.cors import CORSMiddleware
@@ -55,15 +55,26 @@ async def load_model_with_retry(model_name: str, hf_token: str, max_retries: int
55
  token=hf_token
56
  )
57
 
58
- model = AutoModelForCausalLM.from_pretrained(
59
- model_name,
60
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
61
- device_map="auto" if device == "cuda" else None,
62
- trust_remote_code=True,
63
- low_cpu_mem_usage=True,
64
- use_safetensors=True, # Force safetensors to avoid CVE-2025-32434
65
- token=hf_token
66
- )
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  return tokenizer, model
69
 
@@ -101,8 +112,8 @@ async def load_model():
101
  try:
102
  logger.info("Loading model with transformers...")
103
 
104
- # Use Llama 3.1 8B Instruct from official HF storage (not XetHub)
105
- base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
106
 
107
  tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
108
 
@@ -170,42 +181,29 @@ app.add_middleware(
170
  )
171
 
172
  def create_question_prompt(statement: str, num_questions: int, difficulty_level: str) -> str:
173
- """Create a prompt for question generation with reasoning"""
174
 
175
  difficulty_instruction = {
176
- "easy": "Generate simple, straightforward questions that test basic understanding.",
177
- "medium": "Generate questions that require some analysis and comprehension.",
178
- "hard": "Generate complex questions that require deep thinking and reasoning.",
179
- "mixed": "Generate a mix of easy, medium, and hard questions."
180
  }
181
 
182
- system_prompt = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.
183
-
184
- You are an expert educator and question generator. Your task is to create thoughtful, well-crafted questions from given statements."""
185
-
186
- user_prompt = f"""<think>
187
- I need to analyze this statement and generate {num_questions} high-quality questions. Let me think about:
188
- 1. The key concepts and information in the statement
189
- 2. Different types of questions I can ask (factual, analytical, inferential, evaluative)
190
- 3. The difficulty level requested: {difficulty_level}
191
- 4. How to make questions that promote understanding and critical thinking
192
- </think>
193
-
194
- Based on the following statement, generate exactly {num_questions} questions.
195
 
196
- Statement: "{statement}"
197
 
198
  Requirements:
199
- - {difficulty_instruction[difficulty_level]}
200
- - Questions should be clear, well-formed, and grammatically correct
201
- - Vary the question types (what, how, why, when, where, etc.)
202
- - Each question should test different aspects of the statement
203
- - Make questions engaging and thought-provoking
204
  - Number each question (1., 2., 3., etc.)
 
205
 
206
- Generate the questions now:"""
207
 
208
- return f"{system_prompt}\n\n{user_prompt}"
209
 
210
  def extract_questions(generated_text: str) -> List[str]:
211
  """Extract questions from the generated text"""
@@ -275,27 +273,24 @@ async def generate_questions(request: QuestionGenerationRequest):
275
  )
276
 
277
  # Generate response using transformers
278
- inputs = tokenizer.encode(prompt, return_tensors="pt")
279
  if device == "cuda":
280
  inputs = inputs.to(device)
281
 
282
  with torch.no_grad():
 
283
  outputs = model.generate(
284
- inputs,
285
- max_new_tokens=request.max_length,
286
  temperature=request.temperature,
287
  top_p=0.95,
288
- top_k=40,
289
- repetition_penalty=1.1,
290
  do_sample=True,
291
- pad_token_id=tokenizer.eos_token_id,
292
- eos_token_id=tokenizer.eos_token_id,
293
  )
294
 
295
- # Decode the generated text
296
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
297
- # Remove the input prompt from the response
298
- generated_text = full_response[len(prompt):].strip()
299
  logger.info(f"Generated text length: {len(generated_text)}")
300
 
301
  # Extract questions from the generated text
@@ -313,7 +308,7 @@ async def generate_questions(request: QuestionGenerationRequest):
313
  questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
314
 
315
  metadata = {
316
- "model": "DavidAU/Llama-3.1-1-million-ctx-DeepHermes-Deep-Reasoning-8B-GGUF",
317
  "temperature": request.temperature,
318
  "difficulty_level": request.difficulty_level,
319
  "generated_text_length": len(generated_text),
@@ -337,7 +332,7 @@ async def root():
337
  """Root endpoint with basic info"""
338
  return {
339
  "message": "Question Generation API",
340
- "model": "DavidAU/Llama-3.1-1-million-ctx-DeepHermes-Deep-Reasoning-8B-GGUF",
341
  "endpoints": {
342
  "health": "/health",
343
  "generate": "/generate-questions",
 
6
  from contextlib import asynccontextmanager
7
 
8
  import torch
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
10
  import uvicorn
11
  from fastapi import FastAPI, HTTPException, BackgroundTasks
12
  from fastapi.middleware.cors import CORSMiddleware
 
55
  token=hf_token
56
  )
57
 
58
+ # Use Seq2Seq model for T5-based models, CausalLM for others
59
+ if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
60
+ model = AutoModelForSeq2SeqLM.from_pretrained(
61
+ model_name,
62
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
63
+ device_map="auto" if device == "cuda" else None,
64
+ trust_remote_code=True,
65
+ low_cpu_mem_usage=True,
66
+ token=hf_token
67
+ )
68
+ else:
69
+ model = AutoModelForCausalLM.from_pretrained(
70
+ model_name,
71
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
72
+ device_map="auto" if device == "cuda" else None,
73
+ trust_remote_code=True,
74
+ low_cpu_mem_usage=True,
75
+ use_safetensors=True, # Force safetensors to avoid CVE-2025-32434
76
+ token=hf_token
77
+ )
78
 
79
  return tokenizer, model
80
 
 
112
  try:
113
  logger.info("Loading model with transformers...")
114
 
115
+ # Use FLAN-T5 Large - excellent for question generation and uses standard HF storage
116
+ base_model_name = "google/flan-t5-large"
117
 
118
  tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
119
 
 
181
  )
182
 
183
  def create_question_prompt(statement: str, num_questions: int, difficulty_level: str) -> str:
184
+ """Create a prompt for question generation optimized for T5/FLAN models"""
185
 
186
  difficulty_instruction = {
187
+ "easy": "simple, straightforward questions that test basic understanding",
188
+ "medium": "questions that require some analysis and comprehension",
189
+ "hard": "complex questions that require deep thinking and reasoning",
190
+ "mixed": "a mix of easy, medium, and hard questions"
191
  }
192
 
193
+ # T5/FLAN models work better with direct, concise instructions
194
+ prompt = f"""Generate {num_questions} {difficulty_instruction[difficulty_level]} about this statement:
 
 
 
 
 
 
 
 
 
 
 
195
 
196
+ "{statement}"
197
 
198
  Requirements:
199
+ - Clear, well-formed questions
200
+ - Vary question types (what, how, why, when, where)
 
 
 
201
  - Number each question (1., 2., 3., etc.)
202
+ - End each question with a question mark
203
 
204
+ Questions:"""
205
 
206
+ return prompt
207
 
208
  def extract_questions(generated_text: str) -> List[str]:
209
  """Extract questions from the generated text"""
 
273
  )
274
 
275
  # Generate response using transformers
276
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
277
  if device == "cuda":
278
  inputs = inputs.to(device)
279
 
280
  with torch.no_grad():
281
+ # T5 models use generate differently - they don't include input in output
282
  outputs = model.generate(
283
+ **inputs,
284
+ max_new_tokens=min(request.max_length, 512),
285
  temperature=request.temperature,
286
  top_p=0.95,
 
 
287
  do_sample=True,
288
+ num_beams=1,
289
+ early_stopping=True
290
  )
291
 
292
+ # Decode the generated text (T5 doesn't include input prompt in output)
293
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
294
  logger.info(f"Generated text length: {len(generated_text)}")
295
 
296
  # Extract questions from the generated text
 
308
  questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
309
 
310
  metadata = {
311
+ "model": "google/flan-t5-large",
312
  "temperature": request.temperature,
313
  "difficulty_level": request.difficulty_level,
314
  "generated_text_length": len(generated_text),
 
332
  """Root endpoint with basic info"""
333
  return {
334
  "message": "Question Generation API",
335
+ "model": "google/flan-t5-large",
336
  "endpoints": {
337
  "health": "/health",
338
  "generate": "/generate-questions",