Application pushed to huggingface
70ca8e2
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import uvicorn
import os
app = FastAPI()
# Load models and tokenizer
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
generator_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)  # reuse the model loaded above instead of loading it a second time
summarizer_pipe = pipeline("summarization", model="facebook/bart-large-cnn")
app.mount("/static", StaticFiles(directory="static"), name="static")
class GenRequest(BaseModel):
    text: str
    max_new_tokens: int = 150
    do_sample: bool = False
    mode: str = "generate"  # "generate" or "summarize"

@app.get("/", response_class=HTMLResponse)
async def read_root():
    with open("templates/index.html", "r") as f:
        return f.read()

@app.post("/generate")
def generate(req: GenRequest):
    if req.mode == "summarize":
        # Use summarization pipeline
        out = summarizer_pipe(
            req.text,
            max_length=req.max_new_tokens,
            min_length=30,
            do_sample=req.do_sample,
        )
        return {"generated_text": out[0]["summary_text"]}
    else:
        # Use text generation pipeline
        out = generator_pipe(
            req.text,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample,
            truncation=True,
            return_full_text=False,
        )
        return {"generated_text": out[0]["generated_text"]}

@app.post("/predict_next")
def predict_next(req: GenRequest):
    """Get top predictions for next word/token"""
    inputs = tokenizer(req.text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    next_token_logits = outputs.logits[0, -1, :]
    # Get top 10 predictions
    top_k = 10
    probs = torch.softmax(next_token_logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, top_k)
    predictions = []
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        token = tokenizer.decode([idx])
        predictions.append({
            "token": token,
            "probability": round(prob * 100, 2)
        })
    return {"predictions": predictions}

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
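
# Example requests for local testing (a sketch, assuming the server is running
# on the default port 7860 and the models above loaded successfully; the prompt
# strings are placeholders):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Once upon a time", "max_new_tokens": 50, "mode": "generate"}'
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "<long article text>", "max_new_tokens": 120, "mode": "summarize"}'
#
#   curl -X POST http://localhost:7860/predict_next \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The capital of France is"}'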