Spaces:
Paused
Paused
| from fastapi import FastAPI, Query | |
| from pydantic import BaseModel | |
| import cloudscraper | |
| from bs4 import BeautifulSoup | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| import re | |
| import os | |
# Point every Hugging Face cache location at /tmp (presumably because the
# default home-directory cache is not writable in this deployment — TODO confirm).
# NOTE(review): ``transformers`` is imported above, before these assignments run;
# any library that resolves these variables at import time will not see them.
_HF_CACHE_DIR = "/tmp/.cache"
for _cache_var in ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE"):
    os.environ[_cache_var] = _HF_CACHE_DIR
del _cache_var

# The FastAPI application instance for this service.
app = FastAPI()
class ThreadResponse(BaseModel):
    """A scraped discussion thread: its opening post and the replies to it."""

    # Cleaned text of the thread's opening post ("" when nothing was found).
    question: str
    # Cleaned text of each reply, in page order.
    replies: list[str]
# Trailing engagement counter (presumably page chrome), e.g. "5 likes, 3 replies"
# or "1 like, 1 reply". ``repl(?:y|ies)`` accepts the singular "reply"; the
# previous ``replies?`` only matched "replie"/"replies".
_ENGAGEMENT_SUFFIX_RE = re.compile(
    r"\b\d+\s*likes?,?\s*\d*\s*repl(?:y|ies)$", flags=re.IGNORECASE
)


def clean_text(text: str) -> str:
    """Strip surrounding whitespace and a trailing "N likes, M replies" counter.

    Args:
        text: Raw text extracted from a post.

    Returns:
        The text with leading/trailing whitespace removed and, if present at
        the very end, a like/reply counter (singular or plural, any case)
        removed; everything else is left unchanged.
    """
    return _ENGAGEMENT_SUFFIX_RE.sub("", text.strip()).strip()
# Seconds to wait on the remote site. Without a timeout a stalled upstream
# would hold this request (and the worker running it) indefinitely.
SCRAPE_TIMEOUT_SECONDS = 30


# NOTE(review): no route decorator for this function is visible in this view;
# confirm how it is registered on ``app``.
def scrape(url: str = Query(...)):
    """Fetch a thread page and split its posts into a question and replies.

    The first `div.post__content` element becomes the question and the rest
    become replies. An empty result is returned when the page does not answer
    HTTP 200 or contains no such elements. Network errors and timeouts are not
    caught here.
    """
    # SECURITY NOTE(review): ``url`` comes straight from the client and is
    # fetched server-side; consider restricting scheme/host to prevent SSRF.
    # The session is a context manager, so its connection pool is released.
    with cloudscraper.create_scraper() as scraper:
        response = scraper.get(url, timeout=SCRAPE_TIMEOUT_SECONDS)

    if response.status_code != 200:
        return ThreadResponse(question="", replies=[])

    soup = BeautifulSoup(response.content, "html.parser")
    posts = soup.find_all("div", class_="post__content")
    if not posts:
        return ThreadResponse(question="", replies=[])

    texts = [clean_text(post.get_text(strip=True, separator="\n")) for post in posts]
    return ThreadResponse(question=texts[0], replies=texts[1:])
MODEL_NAME = "sarvamai/sarvam-m"

# ``transformers`` is imported at the top of the file, before the cache
# environment variables are assigned, so its default cache location may already
# have been resolved from the original environment. Passing ``cache_dir``
# explicitly makes the download location independent of that import order.
MODEL_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE", "/tmp/.cache")

# Load tokenizer and model once at import time, with device auto-mapping.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    cache_dir=MODEL_CACHE_DIR,
    torch_dtype="auto",
    device_map="auto",
)
model.eval()  # inference only
class PromptRequest(BaseModel):
    """Request body for text generation."""

    # User message sent to the model as a single chat turn.
    prompt: str
import asyncio
import threading

# Serializes all tokenizer/model use. Generation now runs in a worker thread,
# and without this lock several requests could drive the shared model and
# tokenizer concurrently.
_GENERATION_LOCK = threading.Lock()


def _generate_sync(prompt: str) -> str:
    """Run chat-templated generation for *prompt*; return the raw decoded output."""
    messages = [{"role": "user", "content": prompt}]
    with _GENERATION_LOCK:
        # Chat-style input with thinking mode enabled.
        text = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=True)
        # NOTE(review): if the chat template already emits a BOS token, tokenizing
        # with default special tokens may add a second one — confirm.
        inputs = tokenizer([text], return_tensors="pt").to(model.device)
        # torch.no_grad() is thread-local, so it is entered in the worker thread.
        with torch.no_grad():
            # NOTE(review): ``temperature`` only takes effect when sampling is
            # enabled (``do_sample=True`` here or in the model's generation config).
            generated_ids = model.generate(**inputs, max_new_tokens=512, temperature=0.5)
        # Keep only the newly generated tokens of the single batch row.
        output_ids = generated_ids[0, inputs.input_ids.shape[-1]:].tolist()
        return tokenizer.decode(output_ids)


def _strip_eos(text: str) -> str:
    """Trim surrounding whitespace and one trailing ``</s>`` marker."""
    # str.rstrip("</s>") would strip any trailing '<', '/', 's', '>' characters
    # (turning "yes" into "ye"); remove the exact suffix instead.
    return text.strip().removesuffix("</s>").rstrip()


async def generate_text(request: PromptRequest):
    """Generate a reply to ``request.prompt`` with thinking mode enabled.

    Returns a dict with ``reasoning_content`` (text before the first
    ``</think>``, or "" when the tag is absent) and ``generated_text`` (the rest
    of the output, without a trailing ``</s>``).
    """
    # model.generate is blocking; running it in a worker thread keeps the event
    # loop free to serve other requests meanwhile.
    output_text = await asyncio.to_thread(_generate_sync, request.prompt)

    reasoning_content, separator, content = output_text.partition("</think>")
    if not separator:
        # No thinking section: the whole output is the answer.
        reasoning_content, content = "", output_text
    return {
        "reasoning_content": reasoning_content.strip(),
        "generated_text": _strip_eos(content),
    }