Spaces:
Runtime error
Runtime error
File size: 5,080 Bytes
d73377c 124614f 2e4b443 d73377c f4fb652 e119c01 bc7600d 1654164 e119c01 f4fb652 2e4b443 f4fb652 d73377c 9e83766 d73377c 2e4b443 c2dfb5f d73377c 9e83766 d73377c f4fb652 5718641 2e4b443 fd91c53 2e4b443 5718641 2e4b443 5718641 2e4b443 d73377c e119c01 2e4b443 e119c01 e41bfcb 0534a3d d73377c 0534a3d d73377c 0534a3d d73377c 0534a3d d73377c 426922d 041b898 d73377c 2e4b443 2e78639 1d9147e fd91c53 ef7b474 79b024d 2e78639 79b024d 2e4b443 fd91c53 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import os
import torch
from fastapi import FastAPI, File, UploadFile
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import traceback
import re
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import login
from pydantic import BaseModel, Field
from typing import Optional
class GenerateRequest(BaseModel):
    """JSON payload accepted by the ``/generate`` endpoint."""

    # Instruction text supplied by the caller.
    prompt: str
    # Generation budget; the endpoint forwards it as ``max_new_tokens``.
    max_tokens: int = Field(default=1000)
    # Declared so image payloads are visible in the schema. Nothing in the
    # visible code inspects this value, so no rejection happens here.
    image: Optional[str] = Field(
        default=None,
        description="This field should be None. If an image is detected, the request will be rejected.",
    )
class SentimentRequest(BaseModel):
    """JSON payload accepted by the ``/analyze_sentiment`` endpoint."""

    # Text to classify; required (no default).
    text: str = Field(...)
# Authenticate against the Hugging Face Hub when a token is supplied through
# the environment; otherwise continue anonymously and say so on stdout.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("Warning: No HF_TOKEN found in environment variables")
else:
    login(token=HF_TOKEN)
# Runtime switches exported to the process environment.
os.environ.update(
    {
        "TRITON_DISABLE": "1",
        "BNB_DISABLE_TRITON": "1",
        "USE_TORCH": "1",
        "BITSANDBYTES_NOWELCOME": "1",
    }
)

# Point every model/download cache at one directory under /tmp.
# NOTE(review): these variables are assigned after `transformers` and
# `huggingface_hub` have already been imported at the top of the file; if those
# libraries resolve their cache paths at import time, the values set here come
# too late to take effect in this process -- confirm.
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.environ.update(
    {
        "HF_HOME": "/tmp/hf_cache",
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "TORCH_HOME": "/tmp/hf_cache",
    }
)
# ASGI application object; the route decorators below register on it.
app = FastAPI()

# Open CORS policy: any origin, method and header may call the API.
# Credentials are disabled because a wildcard origin combined with
# `allow_credentials=True` is not a valid CORS configuration: browsers refuse
# `Access-Control-Allow-Origin: *` on credentialed requests, and reflecting
# every origin instead would let any site issue credentialed calls. The
# endpoints in this file do not use cookies or auth headers. If credentials
# are ever required, replace the wildcard with an explicit origin list first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Both models are instantiated a single time at import so that request
# handlers can reuse them.
print("Loading models and tokenizers...")

# Text-generation model and its tokenizer.
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Device index for the sentiment pipeline: 0 when CUDA is available, else -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Star-rating classifier; `return_all_scores=True` is kept as-is because the
# sentiment endpoint indexes the result as a list of per-label score lists.
sentiment_analyzer = pipeline(
    "text-classification",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
    device=device,
    return_all_scores=True,
)

print("Models and tokenizers loaded successfully!")
# Chat-markup fragments that may still appear in decoded text; compiled once.
_RESIDUAL_TAG_RE = re.compile(r'</?s>|\[/?s\]|\[/?INST\]|\[/?INSR\]|\{/?INSST\}')


@app.post("/generate")
async def generate_text(
    request: GenerateRequest = None,
    prompt: str = None,
    max_tokens: int = 1000,
    file: Optional[UploadFile] = None
):
    """Generate a completion for a user prompt with the loaded causal LM.

    The prompt and token budget are taken from ``request`` when a body model
    is supplied, otherwise from the ``prompt`` / ``max_tokens`` parameters.

    Args:
        request: Optional body model carrying ``prompt`` and ``max_tokens``.
        prompt: Fallback prompt used when ``request`` is not supplied.
        max_tokens: Fallback value for ``max_new_tokens``.
        file: Optional upload. It is accepted but never read.

    Returns:
        ``{"response": <text>}`` on success, ``{"error": "No prompt provided"}``
        when the prompt is missing or empty, or
        ``{"error": <message>, "traceback": <trace>}`` if generation raises.
    """
    # NOTE(review): a JSON body model and an UploadFile are declared on the
    # same route; confirm FastAPI parses requests the way clients send them.
    if request is not None:
        user_prompt, tokens = request.prompt, request.max_tokens
    else:
        user_prompt, tokens = prompt, max_tokens

    if not user_prompt:
        return {"error": "No prompt provided"}

    try:
        # No literal "<s>" here: the tokenizer adds the BOS token itself, so
        # spelling it out in the text would put two BOS tokens in the input.
        formatted_prompt = f"[INST] {user_prompt} [/INST]"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        prompt_token_count = inputs["input_ids"].shape[-1]

        outputs = model.generate(
            **inputs,
            max_new_tokens=tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )

        # The returned sequence starts with the prompt tokens. Decode only the
        # newly generated ids instead of string-matching the prompt out of the
        # decoded text, which fails once special tokens are skipped and can
        # also cut legitimate output.
        generated_ids = outputs[0][prompt_token_count:]
        raw_response = tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Safety net for markup the model may emit as plain text.
        clean_response = _RESIDUAL_TAG_RE.sub('', raw_response).strip()
        return {"response": clean_response}
    except Exception as e:
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error generating text: {error_msg}")
        print(f"Traceback: {error_trace}")
        # NOTE(review): returning the traceback exposes internal paths and code
        # to every client; kept for response-shape compatibility, but consider
        # logging it server-side only.
        return {"error": error_msg, "traceback": error_trace}
# Maps the classifier's star-rating labels to the sentiment names returned
# by the API.
_STAR_LABEL_TO_SENTIMENT = {
    "1 star": "very_negative",
    "2 stars": "negative",
    "3 stars": "neutral",
    "4 stars": "positive",
    "5 stars": "very_positive",
}


@app.post("/analyze_sentiment")
async def analyze_sentiment(request: SentimentRequest):
    """Classify the sentiment of ``request.text`` on a five-level scale.

    Args:
        request: Body model whose ``text`` field is the input to classify.

    Returns:
        ``{"sentiment": <name>, "raw_scores": {label: score, ...}}`` on
        success, or ``{"error": <message>, "traceback": <trace>}`` if the
        classifier raises or reports a label missing from the mapping.
    """
    try:
        # Truncate long inputs: a bert-base encoder accepts at most 512 tokens,
        # and longer text would otherwise raise inside the model.
        result = sentiment_analyzer(request.text, truncation=True, max_length=512)
        scores = {entry["label"]: entry["score"] for entry in result[0]}

        # Add debug logs
        print(f"Raw scores for '{request.text}': {scores}")

        # The label with the highest score decides the sentiment.
        highest_score_label = max(scores, key=scores.get)
        sentiment = _STAR_LABEL_TO_SENTIMENT[highest_score_label]
        print(f"Final sentiment: {sentiment}")

        return {"sentiment": sentiment, "raw_scores": scores}
    except Exception as e:
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error analyzing sentiment: {error_msg}")
        print(f"Traceback: {error_trace}")
        # NOTE(review): the traceback is returned to the client as in the
        # generation endpoint; consider keeping it in server logs only.
        return {"error": error_msg, "traceback": error_trace}