from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import uvicorn
import os
app = FastAPI()

# Load the causal LM and tokenizer once at startup
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Reuse the already-loaded model and tokenizer so the weights are not loaded twice
generator_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
summarizer_pipe = pipeline("summarization", model="facebook/bart-large-cnn")

# Serve static assets (CSS/JS) for the frontend
app.mount("/static", StaticFiles(directory="static"), name="static")
class GenRequest(BaseModel):
    text: str
    max_new_tokens: int = 150
    do_sample: bool = False
    mode: str = "generate"  # "generate" or "summarize"
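# Illustrative request body matching GenRequest (values are made up for illustration):
# {"text": "Once upon a time", "max_new_tokens": 50, "do_sample": false, "mode": "generate"}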
@app.get("/", response_class=HTMLResponse)  # route path assumed; serves the frontend page
async def read_root():
    with open("templates/index.html", "r") as f:
        return f.read()
@app.post("/generate")  # route path assumed
def generate(req: GenRequest):
    if req.mode == "summarize":
        # Use the summarization pipeline
        out = summarizer_pipe(
            req.text,
            max_length=req.max_new_tokens,
            min_length=30,
            do_sample=req.do_sample,
        )
        return {"generated_text": out[0]["summary_text"]}
    else:
        # Use the text-generation pipeline
        out = generator_pipe(
            req.text,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample,
            truncation=True,
            return_full_text=False,
        )
        return {"generated_text": out[0]["generated_text"]}
@app.post("/predict_next")  # route path assumed
def predict_next(req: GenRequest):
    """Get top predictions for the next word/token."""
    inputs = tokenizer(req.text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    next_token_logits = outputs.logits[0, -1, :]
    # Convert logits to probabilities and keep the top 10 candidates
    top_k = 10
    probs = torch.softmax(next_token_logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, top_k)
    predictions = []
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        token = tokenizer.decode([idx])
        predictions.append({
            "token": token,
            "probability": round(prob * 100, 2)
        })
    return {"predictions": predictions}
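# Example response shape (values are illustrative):
# {"predictions": [{"token": " jumped", "probability": 12.34}, ...]}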
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
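A minimal client sketch, assuming the app is running locally on port 7860 and that the routes are mounted at /generate and /predict_next as shown above (both paths are assumptions, not confirmed by the original listing):

import requests  # client-side sketch; not part of app.py

BASE = "http://localhost:7860"

# Ask the generation endpoint to continue a prompt
gen = requests.post(f"{BASE}/generate", json={"text": "Once upon a time", "max_new_tokens": 50})
print(gen.json()["generated_text"])

# Ask for the top next-token predictions
pred = requests.post(f"{BASE}/predict_next", json={"text": "The capital of France is"})
print(pred.json()["predictions"])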