Spaces:
Sleeping
Sleeping
Switch to Llama-3.1-8B-Instruct: update model loading, prompts, and generation parameters
Browse files- app.py +6 -6
- gradio_app.py +6 -6
app.py
CHANGED
|
@@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any
|
|
| 6 |
from contextlib import asynccontextmanager
|
| 7 |
|
| 8 |
import torch
|
| 9 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM,
|
| 10 |
import uvicorn
|
| 11 |
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -57,7 +57,7 @@ async def load_model_with_retry(model_name: str, hf_token: str, max_retries: int
|
|
| 57 |
|
| 58 |
# Use Seq2Seq model for T5-based models, CausalLM for others
|
| 59 |
if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
|
| 60 |
-
model =
|
| 61 |
model_name,
|
| 62 |
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 63 |
device_map="auto" if device == "cuda" else None,
|
|
@@ -113,7 +113,7 @@ async def load_model():
|
|
| 113 |
logger.info("Loading model with transformers...")
|
| 114 |
|
| 115 |
# Use FLAN-T5 Large - excellent for question generation and uses standard HF storage
|
| 116 |
-
base_model_name = "
|
| 117 |
|
| 118 |
tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
|
| 119 |
|
|
@@ -281,7 +281,7 @@ async def generate_questions(request: QuestionGenerationRequest):
|
|
| 281 |
# T5 models use generate differently - they don't include input in output
|
| 282 |
outputs = model.generate(
|
| 283 |
**inputs,
|
| 284 |
-
max_new_tokens=min(request.max_length,
|
| 285 |
temperature=request.temperature,
|
| 286 |
top_p=0.95,
|
| 287 |
do_sample=True,
|
|
@@ -308,7 +308,7 @@ async def generate_questions(request: QuestionGenerationRequest):
|
|
| 308 |
questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
|
| 309 |
|
| 310 |
metadata = {
|
| 311 |
-
"model": "
|
| 312 |
"temperature": request.temperature,
|
| 313 |
"difficulty_level": request.difficulty_level,
|
| 314 |
"generated_text_length": len(generated_text),
|
|
@@ -332,7 +332,7 @@ async def root():
|
|
| 332 |
"""Root endpoint with basic info"""
|
| 333 |
return {
|
| 334 |
"message": "Question Generation API",
|
| 335 |
-
"model": "
|
| 336 |
"endpoints": {
|
| 337 |
"health": "/health",
|
| 338 |
"generate": "/generate-questions",
|
|
|
|
| 6 |
from contextlib import asynccontextmanager
|
| 7 |
|
| 8 |
import torch
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 10 |
import uvicorn
|
| 11 |
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 57 |
|
| 58 |
# NOTE(review): branch is still keyed on T5-style names but now loads AutoModelForCausalLM below; a true T5 checkpoint would need AutoModelForSeq2SeqLM (no longer imported) — dead path with the current Llama model name, confirm before reuse
|
| 59 |
if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
|
| 60 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 61 |
model_name,
|
| 62 |
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 63 |
device_map="auto" if device == "cuda" else None,
|
|
|
|
| 113 |
logger.info("Loading model with transformers...")
|
| 114 |
|
| 115 |
# Use Llama-3.1-8B-Instruct - strong instruction-following model for question generation (gated repo; requires HF_TOKEN)
|
| 116 |
+
base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
|
| 117 |
|
| 118 |
tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
|
| 119 |
|
|
|
|
| 281 |
# NOTE(review): stale comment — Llama is a causal LM, so generate() output includes the prompt tokens; decoding should slice off the input length before parsing questions
|
| 282 |
outputs = model.generate(
|
| 283 |
**inputs,
|
| 284 |
+
max_new_tokens=min(request.max_length, 1024),
|
| 285 |
temperature=request.temperature,
|
| 286 |
top_p=0.95,
|
| 287 |
do_sample=True,
|
|
|
|
| 308 |
questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
|
| 309 |
|
| 310 |
metadata = {
|
| 311 |
+
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
| 312 |
"temperature": request.temperature,
|
| 313 |
"difficulty_level": request.difficulty_level,
|
| 314 |
"generated_text_length": len(generated_text),
|
|
|
|
| 332 |
"""Root endpoint with basic info"""
|
| 333 |
return {
|
| 334 |
"message": "Question Generation API",
|
| 335 |
+
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
| 336 |
"endpoints": {
|
| 337 |
"health": "/health",
|
| 338 |
"generate": "/generate-questions",
|
gradio_app.py
CHANGED
|
@@ -6,7 +6,7 @@ from typing import List, Optional, Dict, Any
|
|
| 6 |
import threading
|
| 7 |
|
| 8 |
import torch
|
| 9 |
-
from transformers import AutoTokenizer,
|
| 10 |
import gradio as gr
|
| 11 |
|
| 12 |
# Configure logging
|
|
@@ -43,8 +43,8 @@ class ModelManager:
|
|
| 43 |
# Get HF token from environment
|
| 44 |
hf_token = os.getenv("HF_TOKEN")
|
| 45 |
|
| 46 |
-
logger.info("Loading
|
| 47 |
-
base_model_name = "
|
| 48 |
|
| 49 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 50 |
base_model_name,
|
|
@@ -53,7 +53,7 @@ class ModelManager:
|
|
| 53 |
token=hf_token
|
| 54 |
)
|
| 55 |
|
| 56 |
-
self.model =
|
| 57 |
base_model_name,
|
| 58 |
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
| 59 |
device_map="auto" if self.device == "cuda" else None,
|
|
@@ -237,7 +237,7 @@ with gr.Blocks(css=css, title="Question Generation AI", theme=gr.themes.Soft())
|
|
| 237 |
gr.Markdown(
|
| 238 |
"""
|
| 239 |
# 🤖 Question Generation AI
|
| 240 |
-
### Powered by
|
| 241 |
|
| 242 |
Enter any statement or text, and I'll generate thoughtful questions about it. Perfect for creating study materials, assessments, or exploring topics deeper!
|
| 243 |
"""
|
|
@@ -320,7 +320,7 @@ with gr.Blocks(css=css, title="Question Generation AI", theme=gr.themes.Soft())
|
|
| 320 |
"""
|
| 321 |
---
|
| 322 |
<div style="text-align: center; color: #666; font-size: 0.9em;">
|
| 323 |
-
Built with ❤️ using Gradio and
|
| 324 |
<a href="/docs" target="_blank">API Documentation</a>
|
| 325 |
</div>
|
| 326 |
"""
|
|
|
|
| 6 |
import threading
|
| 7 |
|
| 8 |
import torch
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 10 |
import gradio as gr
|
| 11 |
|
| 12 |
# Configure logging
|
|
|
|
| 43 |
# Get HF token from environment
|
| 44 |
hf_token = os.getenv("HF_TOKEN")
|
| 45 |
|
| 46 |
+
logger.info("Loading Llama-3.1-8B-Instruct model...")
|
| 47 |
+
base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
|
| 48 |
|
| 49 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 50 |
base_model_name,
|
|
|
|
| 53 |
token=hf_token
|
| 54 |
)
|
| 55 |
|
| 56 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 57 |
base_model_name,
|
| 58 |
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
| 59 |
device_map="auto" if self.device == "cuda" else None,
|
|
|
|
| 237 |
gr.Markdown(
|
| 238 |
"""
|
| 239 |
# 🤖 Question Generation AI
|
| 240 |
+
### Powered by Llama-3.1-8B-Instruct
|
| 241 |
|
| 242 |
Enter any statement or text, and I'll generate thoughtful questions about it. Perfect for creating study materials, assessments, or exploring topics deeper!
|
| 243 |
"""
|
|
|
|
| 320 |
"""
|
| 321 |
---
|
| 322 |
<div style="text-align: center; color: #666; font-size: 0.9em;">
|
| 323 |
+
Built with ❤️ using Gradio and Llama-3.1-8B-Instruct •
|
| 324 |
<a href="/docs" target="_blank">API Documentation</a>
|
| 325 |
</div>
|
| 326 |
"""
|