# SmolLMV2_135M / app.py
# ----- previous PyTorch Lightning version (superseded by the code below) -----
# import torch
# import pytorch_lightning as pl
# from transformers import AutoModelForCausalLM, AutoTokenizer
# import gradio as gr
# import os
#
# model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
# checkpoint_path = "models/lat_smollm2_135m_quantized.ckpt"
#
# # load tokenizer
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=os.getenv("HF_TOKEN"))
#
# # define class & load model
# class SmolLMv2PL(pl.LightningModule):
#     def __init__(self):
#         super().__init__()
#         self.model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=os.getenv("HF_TOKEN"))
#
#     def forward(self, input_ids):
#         return self.model(input_ids)
#
# if os.path.exists(checkpoint_path):
#     model = SmolLMv2PL.load_from_checkpoint(checkpoint_path, map_location=torch.device("cpu"), strict=False)
#     print("Loaded fine-tuned & quantized checkpoint")
# else:
#     model = SmolLMv2PL()
#     print("Loaded base model")
# model.eval()
#
# # device
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
#
# # prompt generation
# def generate_text(prompt, max_length=100, temperature=1.0):
#     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
#     with torch.no_grad():
#         output_ids = model.model.generate(input_ids, max_length=max_length, temperature=temperature, do_sample=True)
#     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
#
# # gradio interface
# def gradio_interface(prompt, max_length, temperature):
#     return generate_text(prompt, max_length, temperature)
#
# # launching
# iface = gr.Interface(
#     fn=gradio_interface,
#     inputs=[
#         gr.Textbox(label="Prompt"),
#         gr.Slider(50, 500, value=100, step=10, label="Max length"),
#         gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")],
#     outputs=gr.Textbox(label="Generated Text"),
#     title="SmolLMv2-135M Text Generation",
#     description="Generate text using fine-tuned SmolLMv2-135M model"
# )
#
# if __name__ == "__main__":
#     iface.launch(server_name="0.0.0.0", server_port=7860)
### TORCH: active implementation (plain PyTorch + transformers, no Lightning)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import os
model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
checkpoint_path = "models/lat_smollm2_135m_quantized.ckpt"
# load tokenizer ("token" is the current name of the deprecated "use_auth_token" kwarg)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=os.getenv("HF_TOKEN"),
)

# instantiate base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=os.getenv("HF_TOKEN"),
)
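
# Some causal-LM tokenizers ship without a pad token; falling back to EOS keeps
# generate() from warning. (Assumption: EOS-as-pad is acceptable for this app.)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token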
# If the checkpoint was saved with Lightning, it is a dict with a 'state_dict' key.
if os.path.exists(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)
    # Lightning nests parameters under a leading "model." prefix; strip only that
    # first prefix (str.replace would also mangle interior "model." segments,
    # e.g. "model.model.layers.0..." must become "model.layers.0...").
    new_state_dict = {}
    for k, v in state_dict.items():
        new_key = k[len("model."):] if k.startswith("model.") else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict, strict=False)
    print("Loaded fine-tuned & quantized checkpoint")
else:
    print("Loaded base model")
# device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
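
# Optionally halve GPU memory use (assumption: fp16 is acceptable for inference):
# if device.type == "cuda":
#     model.half()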
# prompt generation
def generate_text(prompt: str, max_length: int = 100, temperature: float = 1.0) -> str:
    # tokenizer(...) also returns an attention_mask, which generate() expects
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_length=int(max_length),      # Gradio sliders deliver floats
            temperature=float(temperature),
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # note: max_length counts prompt + generated tokens; max_new_tokens is the alternative
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
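
# Quick smoke test (illustrative; uncomment to try generation outside the UI):
# print(generate_text("Once upon a time", max_length=60, temperature=0.8))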
# gradio interface
def gradio_interface(prompt, max_length, temperature):
    return generate_text(prompt, max_length, temperature)
# launching
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=50, maximum=500, value=100, step=10, label="Max length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLMv2-135M Text Generation",
    description="Generate text using the fine-tuned SmolLMv2-135M model",
)
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)