# STASS / app.py
# (Hugging Face Space page header residue: author Barat123,
#  commit 712d1db "Update app.py", verified — kept as a comment
#  because the raw scrape lines are not valid Python.)
from docx import Document
import pytesseract
from PIL import Image
import fitz
import gradio as gr
import threading
import pathlib
import os
# --------------------------------------------------
# TOKEN RESOLUTION
# --------------------------------------------------
def resolve_token(ui_token):
    """Return the Hugging Face token to use.

    Preference order: the UI textbox value, then the ``hgface_tok``
    environment variable, else the empty string (anonymous access).
    """
    # Gradio may hand us None when the field was never touched; the old
    # code crashed on ui_token.strip() in that case.
    if ui_token and ui_token.strip():
        return ui_token.strip()
    env_token = os.getenv("hgface_tok")
    # Ignore a whitespace-only env var instead of returning "".strip() noise.
    if env_token and env_token.strip():
        return env_token.strip()
    return ""
# --------------------------------------------------
# FILE TEXT EXTRACTION
# --------------------------------------------------
# Extensions the extractor is meant to handle.
# NOTE(review): this tuple is never referenced below — the image branch in
# extract_text_from_file hard-codes its own subset; consider wiring it up.
SUPPORTED_EXT = (
    ".pdf", ".docx", ".txt", ".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"
)
def extract_text_from_file(filepath):
    """Extract plain text from an uploaded file.

    Supports PDF (PyMuPDF), DOCX (python-docx), TXT, and common image
    formats (Tesseract OCR). Errors are returned as explanatory strings
    rather than raised, so the UI can display them.
    """
    if not filepath:
        return ""
    # gr.File may pass a tempfile-like wrapper; unwrap it to its path.
    if hasattr(filepath, "name"):
        filepath = filepath.name
    ext = pathlib.Path(filepath).suffix.lower()
    try:
        if ext == ".pdf":
            doc = fitz.open(filepath)
            try:
                return "\n".join(page.get_text() for page in doc)
            finally:
                # BUG FIX: the document handle was previously never closed.
                doc.close()
        elif ext == ".docx":
            doc = Document(filepath)
            return "\n".join(p.text for p in doc.paragraphs)
        elif ext == ".txt":
            with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        elif ext in (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"):
            try:
                img = Image.open(filepath)
                return pytesseract.image_to_string(img)
            except Exception as e:
                return "OCR failed: " + str(e)
        else:
            return "Unsupported file type: " + ext
    except Exception as e:
        return "Could not read file: " + str(e)
# --------------------------------------------------
# MODELS
# --------------------------------------------------
# Maps the human-readable dropdown label -> Hugging Face model repo id.
# The bracketed size/role hints in the labels are display text only.
MODELS = {
    "Gemma 3 270M [0.6GB | Lightning-fast Edge]": "google/gemma-3-270m-it",
    "Qwen 3 0.6B GGUF [0.5GB | Classroom Assistant]": "Qwen/Qwen3-0.6B-GGUF",
    "TinyLlama 1.1B [0.5GB]": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen 3.5 2B [2.4GB | The Student Tutor]": "Qwen/Qwen3.5-2B",
    "Phi-4 Mini [1.8GB | Logical Powerhouse]": "microsoft/Phi-4-mini-instruct",
    "Gemma 3 1B [2.1GB | Stable & Coherent]": "google/gemma-3-1b-it",
    "Qwen 3.5 9B [7.8GB | BEST FOR LESSON PLANS]": "Qwen/Qwen3.5-9B",
    "Llama 3.1 8B [5.2GB | Industry Standard]": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mistral Small 3 [7.1GB | Concise & Accurate]": "mistralai/Mistral-Small-3-Instruct",
    "Gemma 3 9B [6.3GB | Creative & Safe]": "google/gemma-3-9b-it",
    "Mistral Small 12B [9.5GB | Perfect VRAM Balance]": "mistralai/Mistral-Nemo-Instruct-2407",
    "Qwen 3.5 27B [18GB | Dense Curriculum Architect]": "Qwen/Qwen3.5-27B",
}
# Dropdown choices, preserving the dict's insertion order.
ALL_MODEL_NAMES = list(MODELS.keys())
# --------------------------------------------------
# PIPELINE CACHE
# --------------------------------------------------
# model_id -> loaded transformers pipeline, so each model loads only once.
_pipeline_cache = {}
# Serializes cache access AND model loading across Gradio worker threads.
_pipeline_lock = threading.Lock()
def get_pipeline(model_id, hf_token):
    """Return ``(pipeline, error)`` for *model_id*, loading on first use.

    On success the second element is None; on failure the first is None and
    the second is the exception text. Loading happens under the module lock,
    so two threads never load the same model concurrently.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

    with _pipeline_lock:
        cached = _pipeline_cache.get(model_id)
        if cached is not None:
            return cached, None
        try:
            load_kwargs = {"trust_remote_code": True}
            if hf_token:
                load_kwargs["token"] = hf_token
            tok = AutoTokenizer.from_pretrained(model_id, **load_kwargs)
            mdl = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="cpu",
                **load_kwargs,
            )
            new_pipe = pipeline(
                "text-generation",
                model=mdl,
                tokenizer=tok,
            )
        except Exception as e:
            return None, str(e)
        _pipeline_cache[model_id] = new_pipe
        return new_pipe, None
# --------------------------------------------------
# INFERENCE
# --------------------------------------------------
# Prepended to every prompt so the model answers as a teaching assistant.
SYSTEM_MSG = "You are an expert educational assistant. Use markdown."
def ask_llm(model_label, prompt, hf_token=""):
    """Run *prompt* through the model selected by *model_label*.

    Returns the generated markdown text; all failure modes (unknown label,
    load error, inference error) are returned as strings so the Gradio UI
    can render them instead of crashing.
    """
    # BUG FIX: MODELS[model_label] used to raise an unhandled KeyError for
    # an unknown label; report it like every other error instead.
    model_id = MODELS.get(model_label)
    if model_id is None:
        return "Unknown model: " + str(model_label)
    token = resolve_token(hf_token)
    pipe, err = get_pipeline(model_id, token)
    if err:
        return "Model load error:\n" + err
    try:
        combined = SYSTEM_MSG + "\n\n" + prompt
        out = pipe(
            combined,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.15,
            no_repeat_ngram_size=3,
        )
        text = out[0]["generated_text"]
        # Causal LMs commonly echo the prompt; strip it from the reply.
        if text.startswith(combined):
            text = text[len(combined):]
        return text.strip()
    except Exception as e:
        return "Inference error:\n" + str(e)
# --------------------------------------------------
# PROMPTS
# --------------------------------------------------
def make_prompts(topic):
    """Build the four generation prompts for *topic*, keyed by content type.

    Keys: "lesson", "qa", "mcq", "summary".
    """
    def _for_topic(instruction):
        # Every prompt ends with the topic in its own block.
        return f"{instruction}\n\nTopic:\n{topic}"

    lesson_instruction = (
        "Create a structured lesson plan for classroom teaching.\n"
        "Include:\n"
        "- Learning objectives\n"
        "- Introduction\n"
        "- Concept explanation\n"
        "- Examples\n"
        "- Case study\n"
        "- Classroom activity\n"
        "- Assessment"
    )
    return {
        "lesson": _for_topic(lesson_instruction),
        "qa": _for_topic("Generate 10 exam questions with answers."),
        "mcq": _for_topic("Generate 10 MCQs with 4 options and answers."),
        "summary": _for_topic("Summarize the topic in 250-300 words."),
    }
def generate_content(text, file, model_label, token):
    """Stream the four generated sections to the UI as 4-tuples.

    Yields (lesson, qa, mcq, summary) — one tuple per completed section so
    the tabs update progressively.
    """
    file_text = extract_text_from_file(file) if file else ""
    # `text or ""` guards against a None textbox value.
    syllabus = ((text or "") + "\n\n" + file_text).strip()
    if not syllabus:
        # BUG FIX: this used to yield a 5-tuple ("Provide topic or file",
        # "", "", "", "") against only 4 wired outputs.
        yield ("Provide topic or file", "", "", "")
        return
    prompts = make_prompts(syllabus)
    WAIT = "Generating..."
    # BUG FIX: one placeholder per output tab — the old list had 5 slots
    # for 4 outputs, so every yield mismatched the click handler's outputs.
    results = [WAIT, WAIT, WAIT, WAIT]
    yield tuple(results)
    for i, key in enumerate(("lesson", "qa", "mcq", "summary")):
        results[i] = ask_llm(model_label, prompts[key], token)
        yield tuple(results)
# --------------------------------------------------
# UI
# --------------------------------------------------
# Custom CSS injected into the Gradio page (font override only).
CSS = """
body,.gradio-container{
font-family:Inter,sans-serif!important;
}
"""
# Build the UI.
# BUG FIX: theme= and css= were previously passed to demo.launch(), but
# Blocks.launch() accepts neither keyword (they are gr.Blocks constructor
# arguments) and would raise a TypeError at startup.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="purple",
    ),
    css=CSS,
) as demo:
    gr.Markdown("# 🎓 AI Study Material Generator")
    with gr.Row():
        with gr.Column():
            # Free-text syllabus entry; merged with any uploaded file.
            text_input = gr.Textbox(
                placeholder="Paste syllabus or topic",
                lines=6,
            )
            file_input = gr.File(
                label="Upload syllabus file"
            )
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=ALL_MODEL_NAMES,
                value=ALL_MODEL_NAMES[0],
                label="Model",
            )
            token_box = gr.Textbox(
                label="HF Token (optional)",
                type="password",
            )
    btn = gr.Button("Generate")
    # One tab per generated section; generate_content streams 4-tuples
    # matching these outputs in order.
    with gr.Tabs():
        with gr.TabItem("Lesson Plan"):
            lesson = gr.Markdown()
        with gr.TabItem("Q&A"):
            qa = gr.Markdown()
        with gr.TabItem("MCQ"):
            mcq = gr.Markdown()
        with gr.TabItem("Summary"):
            summary = gr.Markdown()
    btn.click(
        fn=generate_content,
        inputs=[text_input, file_input, model_selector, token_box],
        outputs=[lesson, qa, mcq, summary],
    )

demo.launch()