david167 committed on
Commit
8b5e9db
·
1 Parent(s): e6b5afc

Switch to Llama-3.1-8B-Instruct: update model loading, prompts, and generation parameters

Browse files
Files changed (2) hide show
  1. app.py +6 -6
  2. gradio_app.py +6 -6
app.py CHANGED
@@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any
6
  from contextlib import asynccontextmanager
7
 
8
  import torch
9
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
10
  import uvicorn
11
  from fastapi import FastAPI, HTTPException, BackgroundTasks
12
  from fastapi.middleware.cors import CORSMiddleware
@@ -57,7 +57,7 @@ async def load_model_with_retry(model_name: str, hf_token: str, max_retries: int
57
 
58
  # Use Seq2Seq model for T5-based models, CausalLM for others
59
  if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
60
- model = AutoModelForSeq2SeqLM.from_pretrained(
61
  model_name,
62
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
63
  device_map="auto" if device == "cuda" else None,
@@ -113,7 +113,7 @@ async def load_model():
113
  logger.info("Loading model with transformers...")
114
 
115
  # Use FLAN-T5 Large - excellent for question generation and uses standard HF storage
116
- base_model_name = "google/flan-t5-large"
117
 
118
  tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
119
 
@@ -281,7 +281,7 @@ async def generate_questions(request: QuestionGenerationRequest):
281
  # T5 models use generate differently - they don't include input in output
282
  outputs = model.generate(
283
  **inputs,
284
- max_new_tokens=min(request.max_length, 512),
285
  temperature=request.temperature,
286
  top_p=0.95,
287
  do_sample=True,
@@ -308,7 +308,7 @@ async def generate_questions(request: QuestionGenerationRequest):
308
  questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
309
 
310
  metadata = {
311
- "model": "google/flan-t5-large",
312
  "temperature": request.temperature,
313
  "difficulty_level": request.difficulty_level,
314
  "generated_text_length": len(generated_text),
@@ -332,7 +332,7 @@ async def root():
332
  """Root endpoint with basic info"""
333
  return {
334
  "message": "Question Generation API",
335
- "model": "google/flan-t5-large",
336
  "endpoints": {
337
  "health": "/health",
338
  "generate": "/generate-questions",
 
6
  from contextlib import asynccontextmanager
7
 
8
  import torch
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
10
  import uvicorn
11
  from fastapi import FastAPI, HTTPException, BackgroundTasks
12
  from fastapi.middleware.cors import CORSMiddleware
 
57
 
58
  # NOTE(review): comment is stale — the branch below still tests for T5 model names but now loads AutoModelForCausalLM, so the Seq2Seq path appears vestigial after the Llama switch; confirm intent
59
  if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
60
+ model = AutoModelForCausalLM.from_pretrained(
61
  model_name,
62
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
63
  device_map="auto" if device == "cuda" else None,
 
113
  logger.info("Loading model with transformers...")
114
 
115
  # NOTE(review): comment is stale — the model selected below is now meta-llama/Llama-3.1-8B-Instruct, not FLAN-T5 Large
116
+ base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
117
 
118
  tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
119
 
 
281
  # NOTE(review): comment is stale — unlike T5, a causal LM such as Llama-3.1 DOES include the input prompt in generate() output; verify downstream decoding strips the prompt tokens
282
  outputs = model.generate(
283
  **inputs,
284
+ max_new_tokens=min(request.max_length, 1024),
285
  temperature=request.temperature,
286
  top_p=0.95,
287
  do_sample=True,
 
308
  questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
309
 
310
  metadata = {
311
+ "model": "meta-llama/Llama-3.1-8B-Instruct",
312
  "temperature": request.temperature,
313
  "difficulty_level": request.difficulty_level,
314
  "generated_text_length": len(generated_text),
 
332
  """Root endpoint with basic info"""
333
  return {
334
  "message": "Question Generation API",
335
+ "model": "meta-llama/Llama-3.1-8B-Instruct",
336
  "endpoints": {
337
  "health": "/health",
338
  "generate": "/generate-questions",
gradio_app.py CHANGED
@@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any
6
  import threading
7
 
8
  import torch
9
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
  import gradio as gr
11
 
12
  # Configure logging
@@ -43,8 +43,8 @@ class ModelManager:
43
  # Get HF token from environment
44
  hf_token = os.getenv("HF_TOKEN")
45
 
46
- logger.info("Loading FLAN-T5-Large model...")
47
- base_model_name = "google/flan-t5-large"
48
 
49
  self.tokenizer = AutoTokenizer.from_pretrained(
50
  base_model_name,
@@ -53,7 +53,7 @@ class ModelManager:
53
  token=hf_token
54
  )
55
 
56
- self.model = AutoModelForSeq2SeqLM.from_pretrained(
57
  base_model_name,
58
  torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
59
  device_map="auto" if self.device == "cuda" else None,
@@ -237,7 +237,7 @@ with gr.Blocks(css=css, title="Question Generation AI", theme=gr.themes.Soft())
237
  gr.Markdown(
238
  """
239
  # 🤖 Question Generation AI
240
- ### Powered by FLAN-T5-Large
241
 
242
  Enter any statement or text, and I'll generate thoughtful questions about it. Perfect for creating study materials, assessments, or exploring topics deeper!
243
  """
@@ -320,7 +320,7 @@ with gr.Blocks(css=css, title="Question Generation AI", theme=gr.themes.Soft())
320
  """
321
  ---
322
  <div style="text-align: center; color: #666; font-size: 0.9em;">
323
- Built with ❤️ using Gradio and FLAN-T5-Large
324
  <a href="/docs" target="_blank">API Documentation</a>
325
  </div>
326
  """
 
6
  import threading
7
 
8
  import torch
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
  import gradio as gr
11
 
12
  # Configure logging
 
43
  # Get HF token from environment
44
  hf_token = os.getenv("HF_TOKEN")
45
 
46
+ logger.info("Loading Llama-3.1-8B-Instruct model...")
47
+ base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
48
 
49
  self.tokenizer = AutoTokenizer.from_pretrained(
50
  base_model_name,
 
53
  token=hf_token
54
  )
55
 
56
+ self.model = AutoModelForCausalLM.from_pretrained(
57
  base_model_name,
58
  torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
59
  device_map="auto" if self.device == "cuda" else None,
 
237
  gr.Markdown(
238
  """
239
  # 🤖 Question Generation AI
240
+ ### Powered by Llama-3.1-8B-Instruct
241
 
242
  Enter any statement or text, and I'll generate thoughtful questions about it. Perfect for creating study materials, assessments, or exploring topics deeper!
243
  """
 
320
  """
321
  ---
322
  <div style="text-align: center; color: #666; font-size: 0.9em;">
323
+ Built with ❤️ using Gradio and Llama-3.1-8B-Instruct
324
  <a href="/docs" target="_blank">API Documentation</a>
325
  </div>
326
  """